Diffstat (limited to 'synapse')
 synapse/api/constants.py | 13
 synapse/api/errors.py | 3
 synapse/api/room_versions.py | 67
 synapse/config/emailconfig.py | 4
 synapse/config/experimental.py | 6
 synapse/config/logger.py | 5
 synapse/event_auth.py | 115
 synapse/events/utils.py | 5
 synapse/federation/federation_base.py | 28
 synapse/federation/federation_client.py | 104
 synapse/federation/federation_server.py | 41
 synapse/federation/transport/client.py | 529
 synapse/federation/transport/server.py | 13
 synapse/handlers/event_auth.py | 85
 synapse/handlers/federation.py | 60
 synapse/handlers/initial_sync.py | 7
 synapse/handlers/receipts.py | 58
 synapse/handlers/room.py | 1
 synapse/handlers/room_member.py | 175
 synapse/http/client.py | 4
 synapse/http/federation/matrix_federation_agent.py | 4
 synapse/http/proxyagent.py | 184
 synapse/http/servlet.py | 258
 synapse/logging/context.py | 4
 synapse/logging/handlers.py | 88
 synapse/module_api/__init__.py | 2
 synapse/notifier.py | 5
 synapse/push/mailer.py | 18
 synapse/replication/tcp/client.py | 7
 synapse/rest/admin/users.py | 6
 synapse/rest/client/v1/room.py | 34
 synapse/rest/client/v2_alpha/account.py | 9
 synapse/rest/client/v2_alpha/capabilities.py | 8
 synapse/rest/client/v2_alpha/keys.py | 2
 synapse/rest/client/v2_alpha/read_marker.py | 14
 synapse/rest/client/v2_alpha/receipts.py | 22
 synapse/rest/client/v2_alpha/relations.py | 42
 synapse/rest/client/v2_alpha/sync.py | 2
 synapse/rest/client/versions.py | 2
 synapse/rest/consent/consent_resource.py | 2
 synapse/rest/media/v1/download_resource.py | 2
 synapse/rest/media/v1/preview_url_resource.py | 16
 synapse/state/__init__.py | 38
 synapse/state/v1.py | 40
 synapse/state/v2.py | 11
 synapse/storage/database.py | 51
 synapse/storage/databases/main/__init__.py | 21
 synapse/storage/databases/main/devices.py | 9
 synapse/storage/databases/main/event_federation.py | 95
 synapse/storage/databases/main/events.py | 91
 synapse/storage/databases/main/monthly_active_users.py | 8
 synapse/storage/databases/main/room.py | 2
 synapse/storage/databases/main/state.py | 50
 synapse/storage/databases/main/stats.py | 6
 synapse/storage/databases/main/transactions.py | 8
 synapse/storage/databases/main/user_directory.py | 66
 synapse/storage/databases/state/store.py | 17
 synapse/storage/persist_events.py | 4
 synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql | 49
 synapse/storage/state.py | 4
 synapse/streams/config.py | 16
 synapse/util/async_helpers.py | 28
 synapse/util/caches/cached_call.py | 27
 synapse/util/caches/deferred_cache.py | 15
 synapse/util/caches/descriptors.py | 2
 65 files changed, 2023 insertions(+), 689 deletions(-)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8c7ad2a407..a986fdb47a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -120,6 +120,7 @@ class EventTypes:
     SpaceParent = "m.space.parent"
 
     MSC2716_INSERTION = "org.matrix.msc2716.insertion"
+    MSC2716_CHUNK = "org.matrix.msc2716.chunk"
     MSC2716_MARKER = "org.matrix.msc2716.marker"
 
 
@@ -198,15 +199,13 @@ class EventContentFields:
 
     # Used on normal messages to indicate they were historically imported after the fact
     MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
-    # For "insertion" events
+    # For "insertion" events to indicate what the next chunk ID should be in
+    # order to connect to it
     MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
-    # Used on normal message events to indicate where the chunk connects to
+    # Used on "chunk" events to indicate which insertion event it connects to
     MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
     # For "marker" events
     MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
-    MSC2716_MARKER_INSERTION_PREV_EVENTS = (
-        "org.matrix.msc2716.marker.insertion_prev_events"
-    )
 
 
 class RoomTypes:
@@ -230,3 +229,7 @@ class HistoryVisibility:
     JOINED = "joined"
     SHARED = "shared"
     WORLD_READABLE = "world_readable"
+
+
+class ReadReceiptEventFields:
+    MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
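
For context on the new ReadReceiptEventFields constant: under MSC2285 a client keeps a read receipt private by adding the hidden flag to the receipt body, and the receipts-handler changes later in this diff honour it. A minimal client-side sketch, assuming the standard receipts endpoint and the `requests` library (neither appears in this diff):

    # Hypothetical sketch: send a read receipt that an MSC2285-aware server
    # will keep local instead of sharing with other users or servers.
    import requests

    def send_hidden_receipt(base_url: str, token: str, room_id: str, event_id: str) -> None:
        # The body field matches ReadReceiptEventFields.MSC2285_HIDDEN above.
        resp = requests.post(
            f"{base_url}/_matrix/client/r0/rooms/{room_id}/receipt/m.read/{event_id}",
            headers={"Authorization": f"Bearer {token}"},
            json={"org.matrix.msc2285.hidden": True},
        )
        resp.raise_for_status()
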
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 054ab14ab6..dc662bca83 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -75,6 +75,9 @@ class Codes:
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     BAD_ALIAS = "M_BAD_ALIAS"
+    # For restricted join rules.
+    UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
+    UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
 
 
 class CodeMessageException(RuntimeError):
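
The two new codes are returned for restricted (MSC3083) joins: UNABLE_AUTHORISE_JOIN when the resident server cannot check the allow rules, and UNABLE_TO_GRANT_JOIN when it cannot sign the join. A hedged sketch of how a handler might raise one (the message text is illustrative, not taken from this diff):

    from synapse.api.errors import Codes, SynapseError

    def reject_unauthorisable_join(room_id: str) -> None:
        # Sketch only: signal that this server cannot authorise the join,
        # prompting the joining server to fail over to another candidate.
        raise SynapseError(
            400,
            f"Unable to authorise join to {room_id} via this server",
            errcode=Codes.UNABLE_AUTHORISE_JOIN,
        )
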
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index a20abc5a65..bc678efe49 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict
+from typing import Callable, Dict, Optional
 
 import attr
 
@@ -73,6 +73,9 @@ class RoomVersion:
     # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
     # m.room.membership event with membership 'knock'.
     msc2403_knocking = attr.ib(type=bool)
+    # MSC2716: Adds m.room.power_levels -> content.historical field to control
+    # whether "insertion", "chunk", "marker" events can be sent
+    msc2716_historical = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -88,6 +91,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V2 = RoomVersion(
         "2",
@@ -101,6 +105,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V3 = RoomVersion(
         "3",
@@ -114,6 +119,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V4 = RoomVersion(
         "4",
@@ -127,6 +133,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V5 = RoomVersion(
         "5",
@@ -140,6 +147,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V6 = RoomVersion(
         "6",
@@ -153,6 +161,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -166,9 +175,10 @@ class RoomVersions:
         msc2176_redaction_rules=True,
         msc3083_join_rules=False,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     MSC3083 = RoomVersion(
-        "org.matrix.msc3083",
+        "org.matrix.msc3083.v2",
         RoomDisposition.UNSTABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
@@ -179,6 +189,7 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=True,
         msc2403_knocking=False,
+        msc2716_historical=False,
     )
     V7 = RoomVersion(
         "7",
@@ -192,6 +203,21 @@ class RoomVersions:
         msc2176_redaction_rules=False,
         msc3083_join_rules=False,
         msc2403_knocking=True,
+        msc2716_historical=False,
+    )
+    MSC2716 = RoomVersion(
+        "org.matrix.msc2716",
+        RoomDisposition.STABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=False,
+        msc2403_knocking=True,
+        msc2716_historical=True,
     )
 
 
@@ -207,6 +233,41 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.MSC2176,
         RoomVersions.MSC3083,
         RoomVersions.V7,
+        RoomVersions.MSC2716,
+    )
+}
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomVersionCapability:
+    """An object which describes the unique attributes of a room version."""
+
+    identifier: str  # the identifier for this capability
+    preferred_version: Optional[RoomVersion]
+    support_check_lambda: Callable[[RoomVersion], bool]
+
+
+MSC3244_CAPABILITIES = {
+    cap.identifier: {
+        "preferred": cap.preferred_version.identifier
+        if cap.preferred_version is not None
+        else None,
+        "support": [
+            v.identifier
+            for v in KNOWN_ROOM_VERSIONS.values()
+            if cap.support_check_lambda(v)
+        ],
+    }
+    for cap in (
+        RoomVersionCapability(
+            "knock",
+            RoomVersions.V7,
+            lambda room_version: room_version.msc2403_knocking,
+        ),
+        RoomVersionCapability(
+            "restricted",
+            None,
+            lambda room_version: room_version.msc3083_join_rules,
+        ),
     )
-    # Note that we do not include MSC2043 here unless it is enabled in the config.
 }
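
The MSC3244_CAPABILITIES comprehension flattens each RoomVersionCapability into the wire shape served from /capabilities when msc3244_enabled is set. Evaluated against the room versions defined above, it comes out roughly as follows (illustrative; the exact "support" lists depend on KNOWN_ROOM_VERSIONS at runtime):

    # Approximate value of MSC3244_CAPABILITIES given the versions in this file:
    msc3244_capabilities = {
        "knock": {
            # V7 and the MSC2716 version set msc2403_knocking=True
            "preferred": "7",
            "support": ["7", "org.matrix.msc2716"],
        },
        "restricted": {
            # only the MSC3083 version sets msc3083_join_rules=True
            "preferred": None,
            "support": ["org.matrix.msc3083.v2"],
        },
    }
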
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index bcecbfec03..8d8f166e9b 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -39,12 +39,13 @@ DEFAULT_SUBJECTS = {
     "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...",
     "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...",
     "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...",
+    "invite_from_person_to_space": "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s...",
     "password_reset": "[%(server_name)s] Password reset",
     "email_validation": "[%(server_name)s] Validate your email",
 }
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class EmailSubjectConfig:
     message_from_person_in_room = attr.ib(type=str)
     message_from_person = attr.ib(type=str)
@@ -54,6 +55,7 @@ class EmailSubjectConfig:
     messages_from_person_and_others = attr.ib(type=str)
     invite_from_person = attr.ib(type=str)
     invite_from_person_to_room = attr.ib(type=str)
+    invite_from_person_to_space = attr.ib(type=str)
     password_reset = attr.ib(type=str)
     email_validation = attr.ib(type=str)
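
The new subject template is interpolated with printf-style substitution, like the existing entries in DEFAULT_SUBJECTS. An illustrative substitution (the names are made up):

    subject = (
        "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s..."
        % {"app": "Matrix", "person": "Alice", "space": "Engineering"}
    )
    # -> "[Matrix] Alice has invited you to join the Engineering space on Matrix..."
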
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index e25ccba9ac..4c60ee8c28 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,3 +32,9 @@ class ExperimentalConfig(Config):
 
         # MSC2716 (backfill existing history)
         self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
+
+        # MSC2285 (hidden read receipts)
+        self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
+
+        # MSC3244 (room version capabilities)
+        self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index ad4e6e61c3..dcd3ed1dac 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -71,7 +71,7 @@ handlers:
     # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
     # logs will still be flushed immediately.
     buffer:
-        class: logging.handlers.MemoryHandler
+        class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler
         target: file
         # The capacity is the number of log lines that are buffered before
         # being written to disk. Increasing this will lead to better
@@ -79,6 +79,9 @@ handlers:
         # be written to disk.
         capacity: 10
         flushLevel: 30  # Flush for WARNING logs as well
+        # The period of time, in seconds, between forced flushes.
+        # Messages will not be delayed for longer than this time.
+        period: 5
 
     # A handler that writes logs to stderr. Unused by default, but can be used
     # instead of "buffer" and "file" in the logger handlers.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 137dff2513..0fa7ffc99f 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -106,6 +106,18 @@ def check(
             if not event.signatures.get(event_id_domain):
                 raise AuthError(403, "Event not signed by sending server")
 
+        is_invite_via_allow_rule = (
+            event.type == EventTypes.Member
+            and event.membership == Membership.JOIN
+            and "join_authorised_via_users_server" in event.content
+        )
+        if is_invite_via_allow_rule:
+            authoriser_domain = get_domain_from_id(
+                event.content["join_authorised_via_users_server"]
+            )
+            if not event.signatures.get(authoriser_domain):
+                raise AuthError(403, "Event not signed by authorising server")
+
     # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
     #
     # 1. If type is m.room.create:
@@ -177,7 +189,7 @@ def check(
     # https://github.com/vector-im/vector-web/issues/1208 hopefully
     if event.type == EventTypes.ThirdPartyInvite:
         user_level = get_user_power_level(event.user_id, auth_events)
-        invite_level = _get_named_level(auth_events, "invite", 0)
+        invite_level = get_named_level(auth_events, "invite", 0)
 
         if user_level < invite_level:
             raise AuthError(403, "You don't have permission to invite users")
@@ -193,6 +205,13 @@ def check(
     if event.type == EventTypes.Redaction:
         check_redaction(room_version_obj, event, auth_events)
 
+    if (
+        event.type == EventTypes.MSC2716_INSERTION
+        or event.type == EventTypes.MSC2716_CHUNK
+        or event.type == EventTypes.MSC2716_MARKER
+    ):
+        check_historical(room_version_obj, event, auth_events)
+
     logger.debug("Allowing! %s", event)
 
 
@@ -285,8 +304,8 @@ def _is_membership_change_allowed(
     user_level = get_user_power_level(event.user_id, auth_events)
     target_level = get_user_power_level(target_user_id, auth_events)
 
-    # FIXME (erikj): What should we do here as the default?
-    ban_level = _get_named_level(auth_events, "ban", 50)
+    invite_level = get_named_level(auth_events, "invite", 0)
+    ban_level = get_named_level(auth_events, "ban", 50)
 
     logger.debug(
         "_is_membership_change_allowed: %s",
@@ -336,8 +355,6 @@ def _is_membership_change_allowed(
         elif target_in_room:  # the target is already in the room.
             raise AuthError(403, "%s is already in the room." % target_user_id)
         else:
-            invite_level = _get_named_level(auth_events, "invite", 0)
-
             if user_level < invite_level:
                 raise AuthError(403, "You don't have permission to invite users")
     elif Membership.JOIN == membership:
@@ -345,16 +362,41 @@ def _is_membership_change_allowed(
         # * They are not banned.
         # * They are accepting a previously sent invitation.
         # * They are already joined (it's a NOOP).
-        # * The room is public or restricted.
+        # * The room is public.
+        # * The room is restricted and the user meets the allow rules.
         if event.user_id != target_user_id:
             raise AuthError(403, "Cannot force another user to join.")
         elif target_banned:
             raise AuthError(403, "You are banned from this room")
-        elif join_rule == JoinRules.PUBLIC or (
+        elif join_rule == JoinRules.PUBLIC:
+            pass
+        elif (
             room_version.msc3083_join_rules
             and join_rule == JoinRules.MSC3083_RESTRICTED
         ):
-            pass
+            # This is the same as public, but the event must contain a reference
+            # to the server who authorised the join. If the event does not contain
+            # the proper content it is rejected.
+            #
+            # Note that if the caller is in the room or invited, then they do
+            # not need to meet the allow rules.
+            if not caller_in_room and not caller_invited:
+                authorising_user = event.content.get("join_authorised_via_users_server")
+
+                if authorising_user is None:
+                    raise AuthError(403, "Join event is missing authorising user.")
+
+                # The authorising user must be in the room.
+                key = (EventTypes.Member, authorising_user)
+                member_event = auth_events.get(key)
+                _check_joined_room(member_event, authorising_user, event.room_id)
+
+                authorising_user_level = get_user_power_level(
+                    authorising_user, auth_events
+                )
+                if authorising_user_level < invite_level:
+                    raise AuthError(403, "Join event authorised by invalid server.")
+
         elif join_rule == JoinRules.INVITE or (
             room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
         ):
@@ -369,7 +411,7 @@ def _is_membership_change_allowed(
         if target_banned and user_level < ban_level:
             raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
         elif target_user_id != event.user_id:
-            kick_level = _get_named_level(auth_events, "kick", 50)
+            kick_level = get_named_level(auth_events, "kick", 50)
 
             if user_level < kick_level or user_level <= target_level:
                 raise AuthError(403, "You cannot kick user %s." % target_user_id)
@@ -445,7 +487,7 @@ def get_send_level(
 
 
 def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
-    power_levels_event = _get_power_level_event(auth_events)
+    power_levels_event = get_power_level_event(auth_events)
 
     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
     user_level = get_user_power_level(event.user_id, auth_events)
@@ -485,7 +527,7 @@ def check_redaction(
     """
     user_level = get_user_power_level(event.user_id, auth_events)
 
-    redact_level = _get_named_level(auth_events, "redact", 50)
+    redact_level = get_named_level(auth_events, "redact", 50)
 
     if user_level >= redact_level:
         return False
@@ -504,6 +546,37 @@ def check_redaction(
     raise AuthError(403, "You don't have permission to redact events")
 
 
+def check_historical(
+    room_version_obj: RoomVersion,
+    event: EventBase,
+    auth_events: StateMap[EventBase],
+) -> None:
+    """Check whether the event sender is allowed to send historical related
+    events like "insertion", "chunk", and "marker".
+
+    Returns:
+        None
+
+    Raises:
+        AuthError if the event sender is not allowed to send historical related events
+        ("insertion", "chunk", and "marker").
+    """
+    # Ignore the auth checks in room versions that do not support historical
+    # events
+    if not room_version_obj.msc2716_historical:
+        return
+
+    user_level = get_user_power_level(event.user_id, auth_events)
+
+    historical_level = get_named_level(auth_events, "historical", 100)
+
+    if user_level < historical_level:
+        raise AuthError(
+            403,
            'You don\'t have permission to send historical related events ("insertion", "chunk", and "marker")',
+        )
+
+
 def _check_power_levels(
     room_version_obj: RoomVersion,
     event: EventBase,
@@ -600,7 +673,7 @@ def _check_power_levels(
             )
 
 
-def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
+def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
     return auth_events.get((EventTypes.PowerLevels, ""))
 
 
@@ -616,7 +689,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
     Returns:
         the user's power level in this room.
     """
-    power_level_event = _get_power_level_event(auth_events)
+    power_level_event = get_power_level_event(auth_events)
     if power_level_event:
         level = power_level_event.content.get("users", {}).get(user_id)
         if not level:
@@ -640,8 +713,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
             return 0
 
 
-def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
-    power_level_event = _get_power_level_event(auth_events)
+def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
+    power_level_event = get_power_level_event(auth_events)
 
     if not power_level_event:
         return default
@@ -728,7 +801,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     return public_keys
 
 
-def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
+def auth_types_for_event(
+    room_version: RoomVersion, event: Union[EventBase, EventBuilder]
+) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.
@@ -760,4 +835,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
                 )
                 auth_types.add(key)
 
+        if room_version.msc3083_join_rules and membership == Membership.JOIN:
+            if "join_authorised_via_users_server" in event.content:
+                key = (
+                    EventTypes.Member,
+                    event.content["join_authorised_via_users_server"],
+                )
+                auth_types.add(key)
+
     return auth_types
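
Two behavioural points in the event_auth changes: check_historical gates the three MSC2716 event types behind a new "historical" power level defaulting to 100 (admin-only unless the room lowers it), and restricted joins now pull the authorising user's membership into the auth set. A toy illustration of the default-level lookup, with a plain dict standing in for the m.room.power_levels content:

    def named_level(power_levels_content: dict, name: str, default: int) -> int:
        # Mirrors get_named_level: fall back to the default when the
        # power-levels event does not define the named level.
        level = power_levels_content.get(name)
        return int(level) if level is not None else default

    pl = {"users_default": 0, "ban": 50}               # no "historical" key
    assert named_level(pl, "historical", 100) == 100   # only level-100 users may send
    pl["historical"] = 50
    assert named_level(pl, "historical", 100) == 50    # rooms can lower the bar
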
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ec96999e4e..a0c07f62f4 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -109,6 +109,8 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
+        if room_version.msc3083_join_rules:
+            add_fields("allow")
     elif event_type == EventTypes.PowerLevels:
         add_fields(
             "users",
@@ -124,6 +126,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         if room_version.msc2176_redaction_rules:
             add_fields("invite")
 
+        if room_version.msc2716_historical:
+            add_fields("historical")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2bfe6a3d37..024e440ff4 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
             )
             raise SynapseError(403, errmsg, Codes.FORBIDDEN)
 
+    # If this is a join event for a restricted room it may have been authorised
+    # via a different server from the sending server. Check those signatures.
+    if (
+        room_version.msc3083_join_rules
+        and pdu.type == EventTypes.Member
+        and pdu.membership == Membership.JOIN
+        and "join_authorised_via_users_server" in pdu.content
+    ):
+        authorising_server = get_domain_from_id(
+            pdu.content["join_authorised_via_users_server"]
+        )
+        try:
+            await keyring.verify_event_for_server(
+                authorising_server,
+                pdu,
+                pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+            )
+        except Exception as e:
+            errmsg = (
+                "event id %s: unable to verify signature for authorising server %s: %s"
+                % (
+                    pdu.event_id,
+                    authorising_server,
+                    e,
+                )
+            )
+            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
 
 def _is_invite_via_3pid(event: EventBase) -> bool:
     return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..b7a10da15a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,10 +19,10 @@ import itertools
 import logging
 from typing import (
     TYPE_CHECKING,
-    Any,
     Awaitable,
     Callable,
     Collection,
+    Container,
     Dict,
     Iterable,
     List,
@@ -79,7 +79,15 @@ class InvalidResponseError(RuntimeError):
     we couldn't parse
     """
 
-    pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+    # The event to persist.
+    event: EventBase
+    # A string giving the server the event was sent to.
+    origin: str
+    state: List[EventBase]
+    auth_chain: List[EventBase]
 
 
 class FederationClient(FederationBase):
@@ -506,6 +514,7 @@ class FederationClient(FederationBase):
         description: str,
         destinations: Iterable[str],
         callback: Callable[[str], Awaitable[T]],
+        failover_errcodes: Optional[Container[str]] = None,
         failover_on_unknown_endpoint: bool = False,
     ) -> T:
         """Try an operation on a series of servers, until it succeeds
@@ -526,6 +535,9 @@ class FederationClient(FederationBase):
                 next server tried. Normally the stacktrace is logged but this is
                 suppressed if the exception is an InvalidResponseError.
 
+            failover_errcodes: Error codes (specific to this endpoint) which should
+                cause a failover when received as part of an HTTP 400 error.
+
             failover_on_unknown_endpoint: if True, we will try other servers if it looks
                 like a server doesn't support the endpoint. This is typically useful
                 if the endpoint in question is new or experimental.
@@ -537,6 +549,9 @@ class FederationClient(FederationBase):
             SynapseError if the chosen remote server returns a 300/400 code, or
             no servers were reachable.
         """
+        if failover_errcodes is None:
+            failover_errcodes = ()
+
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -551,11 +566,17 @@ class FederationClient(FederationBase):
                 synapse_error = e.to_synapse_error()
                 failover = False
 
-                # Failover on an internal server error, or if the destination
-                # doesn't implemented the endpoint for some reason.
+                # Failover should occur:
+                #
+                # * On internal server errors.
+                # * If the destination responds that it cannot complete the request.
+                # * If the destination doesn't implement the endpoint for some reason.
                 if 500 <= e.code < 600:
                     failover = True
 
+                elif e.code == 400 and synapse_error.errcode in failover_errcodes:
+                    failover = True
+
                 elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
                     e, synapse_error
                 ):
@@ -671,13 +692,25 @@ class FederationClient(FederationBase):
 
             return destination, ev, room_version
 
+        # MSC3083 defines additional error codes for room joins. Unfortunately
+        # we do not yet know the room version; assume these will only be returned
+        # by valid room versions.
+        failover_errcodes = (
+            (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
+            if membership == Membership.JOIN
+            else None
+        )
+
         return await self._try_destination_list(
-            "make_" + membership, destinations, send_request
+            "make_" + membership,
+            destinations,
+            send_request,
+            failover_errcodes=failover_errcodes,
         )
 
     async def send_join(
         self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
-    ) -> Dict[str, Any]:
+    ) -> SendJoinResult:
         """Sends a join event to one of a list of homeservers.
 
         Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +724,38 @@ class FederationClient(FederationBase):
                 did the make_join)
 
         Returns:
-            a dict with members ``origin`` (a string
-            giving the server the event was sent to, ``state`` (?) and
-            ``auth_chain``.
+            The result of the send join request.
 
         Raises:
             SynapseError: if the chosen remote server returns a 300/400 code, or
                 no servers successfully handle the request.
         """
 
-        async def send_request(destination) -> Dict[str, Any]:
+        async def send_request(destination) -> SendJoinResult:
             response = await self._do_send_join(room_version, destination, pdu)
 
+            # If an event was returned (and expected to be returned):
+            #
+            # * Ensure it has the same event ID (note that the event ID is a hash
+            #   of the event fields for versions which support MSC3083).
+            # * Ensure the signatures are good.
+            #
+            # Otherwise, fall back to the provided event.
+            if room_version.msc3083_join_rules and response.event:
+                event = response.event
+
+                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                    pdu=event,
+                    origin=destination,
+                    outlier=True,
+                    room_version=room_version,
+                )
+
+                if valid_pdu is None or event.event_id != pdu.event_id:
+                    raise InvalidResponseError("Returned an invalid join event")
+            else:
+                event = pdu
+
             state = response.state
             auth_chain = response.auth_events
 
@@ -784,13 +837,32 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
-            return {
-                "state": signed_state,
-                "auth_chain": signed_auth,
-                "origin": destination,
-            }
+            return SendJoinResult(
+                event=event,
+                state=signed_state,
+                auth_chain=signed_auth,
+                origin=destination,
+            )
 
-        return await self._try_destination_list("send_join", destinations, send_request)
+        # MSC3083 defines additional error codes for room joins.
+        failover_errcodes = None
+        if room_version.msc3083_join_rules:
+            failover_errcodes = (
+                Codes.UNABLE_AUTHORISE_JOIN,
+                Codes.UNABLE_TO_GRANT_JOIN,
+            )
+
+            # If the join is being authorised via allow rules, we need to send
+            # the /send_join back to the same server that was originally used
+            # with /make_join.
+            if "join_authorised_via_users_server" in pdu.content:
+                destinations = [
+                    get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+                ]
+
+        return await self._try_destination_list(
+            "send_join", destinations, send_request, failover_errcodes=failover_errcodes
+        )
 
     async def _do_send_join(
         self, room_version: RoomVersion, destination: str, pdu: EventBase
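
With the failover plumbing above, a make_join/send_join attempt now moves on to the next candidate server when an HTTP 400 carries one of the MSC3083 codes, instead of aborting. A condensed sketch of that decision, mirroring (not copying) the branch in _try_destination_list:

    def should_failover(
        code: int,
        errcode: str,
        failover_errcodes: tuple = (),
        failover_on_unknown_endpoint: bool = False,
        endpoint_unknown: bool = False,
    ) -> bool:
        # Mirrors the failover branch added above.
        if 500 <= code < 600:  # remote internal server error
            return True
        if code == 400 and errcode in failover_errcodes:
            return True  # e.g. M_UNABLE_TO_AUTHORISE_JOIN from a restricted join
        return failover_on_unknown_endpoint and endpoint_unknown

    assert should_failover(
        400,
        "M_UNABLE_TO_GRANT_JOIN",
        failover_errcodes=("M_UNABLE_TO_AUTHORISE_JOIN", "M_UNABLE_TO_GRANT_JOIN"),
    )
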
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 29619aeeb8..2892a11d7d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.crypto.event_signing import compute_event_signature
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -586,7 +587,7 @@ class FederationServer(FederationBase):
     async def on_send_join_request(
         self, origin: str, content: JsonDict, room_id: str
     ) -> Dict[str, Any]:
-        context = await self._on_send_membership_event(
+        event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
 
@@ -597,6 +598,7 @@ class FederationServer(FederationBase):
 
         time_now = self._clock.time_msec()
         return {
+            "org.matrix.msc3083.v2.event": event.get_pdu_json(),
             "state": [p.get_pdu_json(time_now) for p in state.values()],
             "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
@@ -681,7 +683,7 @@ class FederationServer(FederationBase):
         Returns:
             The stripped room state.
         """
-        event_context = await self._on_send_membership_event(
+        _, context = await self._on_send_membership_event(
             origin, content, Membership.KNOCK, room_id
         )
 
@@ -690,14 +692,14 @@ class FederationServer(FederationBase):
         # related to the room while the knock request is pending.
         stripped_room_state = (
             await self.store.get_stripped_room_state_from_event_context(
-                event_context, self._room_prejoin_state_types
+                context, self._room_prejoin_state_types
             )
         )
         return {"knock_state_events": stripped_room_state}
 
     async def _on_send_membership_event(
         self, origin: str, content: JsonDict, membership_type: str, room_id: str
-    ) -> EventContext:
+    ) -> Tuple[EventBase, EventContext]:
         """Handle an on_send_{join,leave,knock} request
 
         Does some preliminary validation before passing the request on to the
@@ -712,7 +714,7 @@ class FederationServer(FederationBase):
                 in the event
 
         Returns:
-            The context of the event after inserting it into the room graph.
+            The event and context of the event after inserting it into the room graph.
 
         Raises:
             SynapseError if there is a problem with the request, including things like
@@ -748,6 +750,33 @@ class FederationServer(FederationBase):
 
         logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
 
+        # Sign the event since we're vouching on behalf of the remote server that
+        # the event is valid to be sent into the room. Currently this is only done
+        # if the user is being joined via restricted join rules.
+        if (
+            room_version.msc3083_join_rules
+            and event.membership == Membership.JOIN
+            and "join_authorised_via_users_server" in event.content
+        ):
+            # We can only authorise our own users.
+            authorising_server = get_domain_from_id(
+                event.content["join_authorised_via_users_server"]
+            )
+            if authorising_server != self.server_name:
+                raise SynapseError(
+                    400,
+                    f"Cannot authorise request from resident server: {authorising_server}",
+                )
+
+            event.signatures.update(
+                compute_event_signature(
+                    room_version,
+                    event.get_pdu_json(),
+                    self.hs.hostname,
+                    self.hs.signing_key,
+                )
+            )
+
         event = await self._check_sigs_and_hash(room_version, event)
 
         return await self.handler.on_send_membership_event(origin, event)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 98b1bf77fd..6a8d3ad4fe 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,7 @@
 
 import logging
 import urllib
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 
 import attr
 import ijson
@@ -29,6 +29,7 @@ from synapse.api.urls import (
     FEDERATION_V2_PREFIX,
 )
 from synapse.events import EventBase, make_event_from_dict
+from synapse.federation.units import Transaction
 from synapse.http.matrixfederationclient import ByteParser
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict
@@ -49,23 +50,25 @@ class TransportLayerClient:
         self.client = hs.get_federation_http_client()
 
     @log_function
-    def get_room_state_ids(self, destination, room_id, event_id):
+    async def get_room_state_ids(
+        self, destination: str, room_id: str, event_id: str
+    ) -> JsonDict:
         """Requests all state for a given room from the given server at the
         given event. Returns the state's event IDs.
 
         Args:
-            destination (str): The host name of the remote homeserver we want
+            destination: The host name of the remote homeserver we want
                 to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
+            room_id: The room we want the state of
+            event_id: The event we want the context at.
 
         Returns:
-            Awaitable: Results in a dict received from the remote homeserver.
+            Results in a dict received from the remote homeserver.
         """
         logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
 
         path = _create_v1_path("/state_ids/%s", room_id)
-        return self.client.get_json(
+        return await self.client.get_json(
             destination,
             path=path,
             args={"event_id": event_id},
@@ -73,39 +76,43 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_event(self, destination, event_id, timeout=None):
+    async def get_event(
+        self, destination: str, event_id: str, timeout: Optional[int] = None
+    ) -> JsonDict:
         """Requests the pdu with give id and origin from the given server.
 
         Args:
-            destination (str): The host name of the remote homeserver we want
+            destination: The host name of the remote homeserver we want
                 to get the state from.
-            event_id (str): The id of the event being requested.
-            timeout (int): How long to try (in ms) the destination for before
+            event_id: The id of the event being requested.
+            timeout: How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
 
         Returns:
-            Awaitable: Results in a dict received from the remote homeserver.
+            Results in a dict received from the remote homeserver.
         """
         logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
 
         path = _create_v1_path("/event/%s", event_id)
-        return self.client.get_json(
+        return await self.client.get_json(
             destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
         )
 
     @log_function
-    def backfill(self, destination, room_id, event_tuples, limit):
+    async def backfill(
+        self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int
+    ) -> Optional[JsonDict]:
         """Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
         Args:
-            dest (str)
-            room_id (str)
-            event_tuples (list)
-            limit (int)
+            destination
+            room_id
+            event_tuples
+            limit
 
         Returns:
-            Awaitable: Results in a dict received from the remote homeserver.
+            Results in a dict received from the remote homeserver.
         """
         logger.debug(
             "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@@ -117,18 +124,22 @@ class TransportLayerClient:
 
         if not event_tuples:
             # TODO: raise?
-            return
+            return None
 
         path = _create_v1_path("/backfill/%s", room_id)
 
         args = {"v": event_tuples, "limit": [str(limit)]}
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination, path=path, args=args, try_trailing_slash_on_400=True
         )
 
     @log_function
-    async def send_transaction(self, transaction, json_data_callback=None):
+    async def send_transaction(
+        self,
+        transaction: Transaction,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+    ) -> JsonDict:
         """Sends the given Transaction to its destination
 
         Args:
@@ -149,21 +160,21 @@ class TransportLayerClient:
         """
         logger.debug(
             "send_data dest=%s, txid=%s",
-            transaction.destination,
-            transaction.transaction_id,
+            transaction.destination,  # type: ignore
+            transaction.transaction_id,  # type: ignore
         )
 
-        if transaction.destination == self.server_name:
+        if transaction.destination == self.server_name:  # type: ignore
             raise RuntimeError("Transport layer cannot send to itself!")
 
         # FIXME: This is only used by the tests. The actual json sent is
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
-        path = _create_v1_path("/send/%s", transaction.transaction_id)
+        path = _create_v1_path("/send/%s", transaction.transaction_id)  # type: ignore
 
-        response = await self.client.put_json(
-            transaction.destination,
+        return await self.client.put_json(
+            transaction.destination,  # type: ignore
             path=path,
             data=json_data,
             json_data_callback=json_data_callback,
@@ -172,8 +183,6 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
-        return response
-
     @log_function
     async def make_query(
         self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
@@ -193,8 +202,13 @@ class TransportLayerClient:
 
     @log_function
     async def make_membership_event(
-        self, destination, room_id, user_id, membership, params
-    ):
+        self,
+        destination: str,
+        room_id: str,
+        user_id: str,
+        membership: str,
+        params: Optional[Mapping[str, Union[str, Iterable[str]]]],
+    ) -> JsonDict:
         """Asks a remote server to build and sign us a membership event
 
         Note that this does not append any events to any graphs.
@@ -240,7 +254,7 @@ class TransportLayerClient:
             ignore_backoff = True
             retry_on_dns_fail = True
 
-        content = await self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args=params,
@@ -249,20 +263,18 @@ class TransportLayerClient:
             ignore_backoff=ignore_backoff,
         )
 
-        return content
-
     @log_function
     async def send_join_v1(
         self,
-        room_version,
-        destination,
-        room_id,
-        event_id,
-        content,
+        room_version: RoomVersion,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        content: JsonDict,
     ) -> "SendJoinResponse":
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -270,15 +282,18 @@ class TransportLayerClient:
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )
 
-        return response
-
     @log_function
     async def send_join_v2(
-        self, room_version, destination, room_id, event_id, content
+        self,
+        room_version: RoomVersion,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        content: JsonDict,
     ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -286,13 +301,13 @@ class TransportLayerClient:
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
         )
 
-        return response
-
     @log_function
-    async def send_leave_v1(self, destination, room_id, event_id, content):
+    async def send_leave_v1(
+        self, destination: str, room_id: str, event_id: str, content: JsonDict
+    ) -> Tuple[int, JsonDict]:
         path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -303,13 +318,13 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-        return response
-
     @log_function
-    async def send_leave_v2(self, destination, room_id, event_id, content):
+    async def send_leave_v2(
+        self, destination: str, room_id: str, event_id: str, content: JsonDict
+    ) -> JsonDict:
         path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination,
             path=path,
             data=content,
@@ -320,8 +335,6 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-        return response
-
     @log_function
     async def send_knock_v1(
         self,
@@ -357,25 +370,25 @@ class TransportLayerClient:
         )
 
     @log_function
-    async def send_invite_v1(self, destination, room_id, event_id, content):
+    async def send_invite_v1(
+        self, destination: str, room_id: str, event_id: str, content: JsonDict
+    ) -> Tuple[int, JsonDict]:
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-        return response
-
     @log_function
-    async def send_invite_v2(self, destination, room_id, event_id, content):
+    async def send_invite_v2(
+        self, destination: str, room_id: str, event_id: str, content: JsonDict
+    ) -> JsonDict:
         path = _create_v2_path("/invite/%s/%s", room_id, event_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-        return response
-
     @log_function
     async def get_public_rooms(
         self,
@@ -385,7 +398,7 @@ class TransportLayerClient:
         search_filter: Optional[Dict] = None,
         include_all_networks: bool = False,
         third_party_instance_id: Optional[str] = None,
-    ):
+    ) -> JsonDict:
         """Get the list of public rooms from a remote homeserver
 
         See synapse.federation.federation_client.FederationClient.get_public_rooms for
@@ -450,25 +463,27 @@ class TransportLayerClient:
         return response
 
     @log_function
-    async def exchange_third_party_invite(self, destination, room_id, event_dict):
+    async def exchange_third_party_invite(
+        self, destination: str, room_id: str, event_dict: JsonDict
+    ) -> JsonDict:
         path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
 
-        response = await self.client.put_json(
+        return await self.client.put_json(
             destination=destination, path=path, data=event_dict
         )
 
-        return response
-
     @log_function
-    async def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(
+        self, destination: str, room_id: str, event_id: str
+    ) -> JsonDict:
         path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
 
-        content = await self.client.get_json(destination=destination, path=path)
-
-        return content
+        return await self.client.get_json(destination=destination, path=path)
 
     @log_function
-    async def query_client_keys(self, destination, query_content, timeout):
+    async def query_client_keys(
+        self, destination: str, query_content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Query the device keys for a list of user ids hosted on a remote
         server.
 
@@ -496,20 +511,21 @@ class TransportLayerClient:
             }
 
         Args:
-            destination(str): The server to query.
-            query_content(dict): The user ids to query.
+            destination: The server to query.
+            query_content: The user ids to query.
         Returns:
             A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/keys/query")
 
-        content = await self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
-        return content
 
     @log_function
-    async def query_user_devices(self, destination, user_id, timeout):
+    async def query_user_devices(
+        self, destination: str, user_id: str, timeout: int
+    ) -> JsonDict:
         """Query the devices for a user id hosted on a remote server.
 
         Response:
@@ -535,20 +551,21 @@ class TransportLayerClient:
             }
 
         Args:
-            destination(str): The server to query.
-            query_content(dict): The user ids to query.
+            destination: The server to query.
+            user_id: The user id to query devices for.
         Returns:
             A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
-        content = await self.client.get_json(
+        return await self.client.get_json(
             destination=destination, path=path, timeout=timeout
         )
-        return content
 
     @log_function
-    async def claim_client_keys(self, destination, query_content, timeout):
+    async def claim_client_keys(
+        self, destination: str, query_content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
         Request:
@@ -572,33 +589,32 @@ class TransportLayerClient:
             }
 
         Args:
-            destination(str): The server to query.
-            query_content(dict): The user ids to query.
+            destination: The server to query.
+            query_content: The user ids to query.
         Returns:
             A dict containing the one-time keys.
         """
 
         path = _create_v1_path("/user/keys/claim")
 
-        content = await self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=query_content, timeout=timeout
         )
-        return content
 
     @log_function
     async def get_missing_events(
         self,
-        destination,
-        room_id,
-        earliest_events,
-        latest_events,
-        limit,
-        min_depth,
-        timeout,
-    ):
+        destination: str,
+        room_id: str,
+        earliest_events: Iterable[str],
+        latest_events: Iterable[str],
+        limit: int,
+        min_depth: int,
+        timeout: int,
+    ) -> JsonDict:
         path = _create_v1_path("/get_missing_events/%s", room_id)
 
-        content = await self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             data={
@@ -610,14 +626,14 @@ class TransportLayerClient:
             timeout=timeout,
         )
 
-        return content
-
     @log_function
-    def get_group_profile(self, destination, group_id, requester_user_id):
+    async def get_group_profile(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get a group profile"""
         path = _create_v1_path("/groups/%s/profile", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -625,14 +641,16 @@ class TransportLayerClient:
         )
 
     @log_function
-    def update_group_profile(self, destination, group_id, requester_user_id, content):
+    async def update_group_profile(
+        self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Update a remote group profile
 
         Args:
-            destination (str)
-            group_id (str)
-            requester_user_id (str)
-            content (dict): The new profile of the group
+            destination
+            group_id
+            requester_user_id
+            content: The new profile of the group
         """
         path = _create_v1_path("/groups/%s/profile", group_id)
 
@@ -645,11 +663,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_group_summary(self, destination, group_id, requester_user_id):
+    async def get_group_summary(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get a group summary"""
         path = _create_v1_path("/groups/%s/summary", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -657,24 +677,31 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_rooms_in_group(self, destination, group_id, requester_user_id):
+    async def get_rooms_in_group(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get all rooms in a group"""
         path = _create_v1_path("/groups/%s/rooms", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
             ignore_backoff=True,
         )
 
-    def add_room_to_group(
-        self, destination, group_id, requester_user_id, room_id, content
-    ):
+    async def add_room_to_group(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Add a room to a group"""
         path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -682,15 +709,21 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    def update_room_in_group(
-        self, destination, group_id, requester_user_id, room_id, config_key, content
-    ):
+    async def update_room_in_group(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        config_key: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update room in group"""
         path = _create_v1_path(
             "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
         )
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -698,11 +731,13 @@ class TransportLayerClient:
             ignore_backoff=True,
         )
 
-    def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+    async def remove_room_from_group(
+        self, destination: str, group_id: str, requester_user_id: str, room_id: str
+    ) -> JsonDict:
         """Remove a room from a group"""
         path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
 
-        return self.client.delete_json(
+        return await self.client.delete_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -710,11 +745,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_users_in_group(self, destination, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users in a group"""
         path = _create_v1_path("/groups/%s/users", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -722,11 +759,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+    async def get_invited_users_in_group(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users that have been invited to a group"""
         path = _create_v1_path("/groups/%s/invited_users", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -734,16 +773,20 @@ class TransportLayerClient:
         )
 
     @log_function
-    def accept_group_invite(self, destination, group_id, user_id, content):
+    async def accept_group_invite(
+        self, destination: str, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Accept a group invite"""
         path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def join_group(self, destination, group_id, user_id, content):
+    def join_group(
+        self, destination: str, group_id: str, user_id: str, content: JsonDict
+    ) -> Awaitable[JsonDict]:
         """Attempts to join a group"""
         path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
 
@@ -752,13 +795,18 @@ class TransportLayerClient:
         )
 
     @log_function
-    def invite_to_group(
-        self, destination, group_id, user_id, requester_user_id, content
-    ):
+    async def invite_to_group(
+        self,
+        destination: str,
+        group_id: str,
+        user_id: str,
+        requester_user_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Invite a user to a group"""
         path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -767,25 +815,32 @@ class TransportLayerClient:
         )
 
     @log_function
-    def invite_to_group_notification(self, destination, group_id, user_id, content):
+    async def invite_to_group_notification(
+        self, destination: str, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sent by group server to inform a user's server that they have been
         invited.
         """
 
         path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def remove_user_from_group(
-        self, destination, group_id, requester_user_id, user_id, content
-    ):
+    async def remove_user_from_group(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        user_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Remove a user from a group"""
         path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -794,35 +849,43 @@ class TransportLayerClient:
         )
 
     @log_function
-    def remove_user_from_group_notification(
-        self, destination, group_id, user_id, content
-    ):
+    async def remove_user_from_group_notification(
+        self, destination: str, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sent by group server to inform a user's server that they have been
         kicked from the group.
         """
 
         path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def renew_group_attestation(self, destination, group_id, user_id, content):
+    async def renew_group_attestation(
+        self, destination: str, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sent by either a group server or a user's server to periodically update
         the attestations
         """
 
         path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
     @log_function
-    def update_group_summary_room(
-        self, destination, group_id, user_id, room_id, category_id, content
-    ):
+    async def update_group_summary_room(
+        self,
+        destination: str,
+        group_id: str,
+        user_id: str,
+        room_id: str,
+        category_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update a room entry in a group summary"""
         if category_id:
             path = _create_v1_path(
@@ -834,7 +897,7 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": user_id},
@@ -843,9 +906,14 @@ class TransportLayerClient:
         )
 
     @log_function
-    def delete_group_summary_room(
-        self, destination, group_id, user_id, room_id, category_id
-    ):
+    async def delete_group_summary_room(
+        self,
+        destination: str,
+        group_id: str,
+        user_id: str,
+        room_id: str,
+        category_id: str,
+    ) -> JsonDict:
         """Delete a room entry in a group summary"""
         if category_id:
             path = _create_v1_path(
@@ -857,7 +925,7 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
 
-        return self.client.delete_json(
+        return await self.client.delete_json(
             destination=destination,
             path=path,
             args={"requester_user_id": user_id},
@@ -865,11 +933,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_group_categories(self, destination, group_id, requester_user_id):
+    async def get_group_categories(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get all categories in a group"""
         path = _create_v1_path("/groups/%s/categories", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -877,11 +947,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_group_category(self, destination, group_id, requester_user_id, category_id):
+    async def get_group_category(
+        self, destination: str, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
         """Get category info in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -889,13 +961,18 @@ class TransportLayerClient:
         )
 
     @log_function
-    def update_group_category(
-        self, destination, group_id, requester_user_id, category_id, content
-    ):
+    async def update_group_category(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        category_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update a category in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -904,13 +981,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def delete_group_category(
-        self, destination, group_id, requester_user_id, category_id
-    ):
+    async def delete_group_category(
+        self, destination: str, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
         """Delete a category in a group"""
         path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
 
-        return self.client.delete_json(
+        return await self.client.delete_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -918,11 +995,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_group_roles(self, destination, group_id, requester_user_id):
+    async def get_group_roles(
+        self, destination: str, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get all roles in a group"""
         path = _create_v1_path("/groups/%s/roles", group_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -930,11 +1009,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def get_group_role(self, destination, group_id, requester_user_id, role_id):
+    async def get_group_role(
+        self, destination: str, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
         """Get a roles info"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
-        return self.client.get_json(
+        return await self.client.get_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -942,13 +1023,18 @@ class TransportLayerClient:
         )
 
     @log_function
-    def update_group_role(
-        self, destination, group_id, requester_user_id, role_id, content
-    ):
+    async def update_group_role(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        role_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update a role in a group"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -957,11 +1043,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+    async def delete_group_role(
+        self, destination: str, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
         """Delete a role in a group"""
         path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
 
-        return self.client.delete_json(
+        return await self.client.delete_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -969,9 +1057,15 @@ class TransportLayerClient:
         )
 
     @log_function
-    def update_group_summary_user(
-        self, destination, group_id, requester_user_id, user_id, role_id, content
-    ):
+    async def update_group_summary_user(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        user_id: str,
+        role_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update a users entry in a group"""
         if role_id:
             path = _create_v1_path(
@@ -980,7 +1074,7 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -989,11 +1083,13 @@ class TransportLayerClient:
         )
 
     @log_function
-    def set_group_join_policy(self, destination, group_id, requester_user_id, content):
+    async def set_group_join_policy(
+        self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sets the join policy for a group"""
         path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
 
-        return self.client.put_json(
+        return await self.client.put_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
@@ -1002,9 +1098,14 @@ class TransportLayerClient:
         )
 
     @log_function
-    def delete_group_summary_user(
-        self, destination, group_id, requester_user_id, user_id, role_id
-    ):
+    async def delete_group_summary_user(
+        self,
+        destination: str,
+        group_id: str,
+        requester_user_id: str,
+        user_id: str,
+        role_id: str,
+    ) -> JsonDict:
         """Delete a users entry in a group"""
         if role_id:
             path = _create_v1_path(
@@ -1013,33 +1114,35 @@ class TransportLayerClient:
         else:
             path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
 
-        return self.client.delete_json(
+        return await self.client.delete_json(
             destination=destination,
             path=path,
             args={"requester_user_id": requester_user_id},
             ignore_backoff=True,
         )
 
-    def bulk_get_publicised_groups(self, destination, user_ids):
+    async def bulk_get_publicised_groups(
+        self, destination: str, user_ids: Iterable[str]
+    ) -> JsonDict:
         """Get the groups a list of users are publicising"""
 
         path = _create_v1_path("/get_groups_publicised")
 
         content = {"user_ids": user_ids}
 
-        return self.client.post_json(
+        return await self.client.post_json(
             destination=destination, path=path, data=content, ignore_backoff=True
         )
 
-    def get_room_complexity(self, destination, room_id):
+    async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
         """
         Args:
-            destination (str): The remote server
-            room_id (str): The room ID to ask about.
+            destination: The remote server
+            room_id: The room ID to ask about.
         """
         path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
 
-        return self.client.get_json(destination=destination, path=path)
+        return await self.client.get_json(destination=destination, path=path)
 
     async def get_space_summary(
         self,
@@ -1075,14 +1178,14 @@ class TransportLayerClient:
         )
 
 
-def _create_path(federation_prefix, path, *args):
+def _create_path(federation_prefix: str, path: str, *args: str) -> str:
     """
     Ensures that all args are url encoded.
     """
     return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
 
 
-def _create_v1_path(path, *args):
+def _create_v1_path(path: str, *args: str) -> str:
     """Creates a path against V1 federation API from the path template and
     args. Ensures that all args are url encoded.
 
@@ -1091,16 +1194,13 @@ def _create_v1_path(path, *args):
         _create_v1_path("/event/%s", event_id)
 
     Args:
-        path (str): String template for the path
-        args: ([str]): Args to insert into path. Each arg will be url encoded
-
-    Returns:
-        str
+        path: String template for the path
+        args: Args to insert into path. Each arg will be url encoded
     """
     return _create_path(FEDERATION_V1_PREFIX, path, *args)
 
 
-def _create_v2_path(path, *args):
+def _create_v2_path(path: str, *args: str) -> str:
     """Creates a path against V2 federation API from the path template and
     args. Ensures that all args are url encoded.
 
@@ -1109,11 +1209,8 @@ def _create_v2_path(path, *args):
         _create_v2_path("/event/%s", event_id)
 
     Args:
-        path (str): String template for the path
-        args: ([str]): Args to insert into path. Each arg will be url encoded
-
-    Returns:
-        str
+        path: String template for the path
+        args: Args to insert into path. Each arg will be url encoded
     """
     return _create_path(FEDERATION_V2_PREFIX, path, *args)
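A worked example of the URL-encoding guarantee these helpers provide, assuming the usual /_matrix/federation/v1 prefix and a made-up event ID:

    _create_v1_path("/event/%s", "$ev!:example.com")
    # -> "/_matrix/federation/v1/event/%24ev%21%3Aexample.com"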
 
@@ -1122,8 +1219,26 @@ def _create_v2_path(path, *args):
 class SendJoinResponse:
     """The parsed response of a `/send_join` request."""
 
+    # The list of auth events from the /send_join response.
     auth_events: List[EventBase]
+    # The list of state from the /send_join response.
     state: List[EventBase]
+    # The raw join event from the /send_join response.
+    event_dict: JsonDict
+    # The parsed join event from the /send_join response. This will be None if
+    # "event" is not included in the response.
+    event: Optional[EventBase] = None
+
+
+@ijson.coroutine
+def _event_parser(event_dict: JsonDict):
+    """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
+    to add them to a given dictionary.
+    """
+
+    while True:
+        key, value = yield
+        event_dict[key] = value
 
 
 @ijson.coroutine
@@ -1149,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
     CONTENT_TYPE = "application/json"
 
     def __init__(self, room_version: RoomVersion, v1_api: bool):
-        self._response = SendJoinResponse([], [])
+        self._response = SendJoinResponse([], [], {})
+        self._room_version = room_version
 
         # The V1 API has the shape of `[200, {...}]`, which we handle by
         # prefixing with `item.*`.
@@ -1163,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
             _event_list_parser(room_version, self._response.auth_events),
             prefix + "auth_chain.item",
         )
+        self._coro_event = ijson.kvitems_coro(
+            _event_parser(self._response.event_dict),
+            prefix + "org.matrix.msc3083.v2.event",
+        )
 
     def write(self, data: bytes) -> int:
         self._coro_state.send(data)
         self._coro_auth.send(data)
+        self._coro_event.send(data)
 
         return len(data)
 
     def finish(self) -> SendJoinResponse:
+        if self._response.event_dict:
+            self._response.event = make_event_from_dict(
+                self._response.event_dict, self._room_version
+            )
         return self._response
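For readers unfamiliar with ijson's push interface, a minimal standalone sketch of the pattern SendJoinParser uses above; the helper name and sample payload are illustrative only:

    import ijson

    @ijson.coroutine
    def _collect(target: dict):
        # Receive (key, value) pairs from ijson and stash them.
        while True:
            key, value = yield
            target[key] = value

    event_dict: dict = {}
    coro = ijson.kvitems_coro(_collect(event_dict), "org.matrix.msc3083.v2.event")
    # Bytes may arrive in arbitrary chunks; the parser keeps its own state.
    coro.send(b'{"org.matrix.msc3083.v2.event": {"type": "m.room.member",')
    coro.send(b' "state_key": "@alice:example.com"}}')
    # event_dict == {"type": "m.room.member", "state_key": "@alice:example.com"}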
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2974d4d0cc..5e059d6e09 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet):
         limit = parse_integer_from_args(query, "limit", 0)
         since_token = parse_string_from_args(query, "since", None)
         include_all_networks = parse_boolean_from_args(
-            query, "include_all_networks", False
+            query, "include_all_networks", default=False
         )
         third_party_instance_id = parse_string_from_args(
             query, "third_party_instance_id", None
@@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
         suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
         max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
 
-        exclude_rooms = []
-        if b"exclude_rooms" in query:
-            try:
-                exclude_rooms = [
-                    room_id.decode("ascii") for room_id in query[b"exclude_rooms"]
-                ]
-            except Exception:
-                raise SynapseError(
-                    400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM
-                )
+        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
 
         return 200, await self.handler.federation_space_summary(
             origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 41dbdfd0a1..53fac1f8a3 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import TYPE_CHECKING, Collection, List, Optional, Union
 
 from synapse import event_auth
@@ -20,16 +21,18 @@ from synapse.api.constants import (
     Membership,
     RestrictedJoinRuleTypes,
 )
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.types import StateMap
+from synapse.types import StateMap, get_domain_from_id
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
 
 class EventAuthHandler:
     """
@@ -39,6 +42,7 @@ class EventAuthHandler:
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self._store = hs.get_datastore()
+        self._server_name = hs.hostname
 
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
@@ -81,15 +85,76 @@ class EventAuthHandler:
         # introduce undesirable "state reset" behaviour.
         #
         # All of which sounds a bit tricky so we don't bother for now.
-
         auth_ids = []
-        for etype, state_key in event_auth.auth_types_for_event(event):
+        for etype, state_key in event_auth.auth_types_for_event(
+            event.room_version, event
+        ):
             auth_ev_id = current_state_ids.get((etype, state_key))
             if auth_ev_id:
                 auth_ids.append(auth_ev_id)
 
         return auth_ids
 
+    async def get_user_which_could_invite(
+        self, room_id: str, current_state_ids: StateMap[str]
+    ) -> str:
+        """
+        Searches the room state for a local user who has the power level necessary
+        to invite other users.
+
+        Args:
+            room_id: The room ID under search.
+            current_state_ids: The current state of the room.
+
+        Returns:
+            The MXID of a local user who can issue an invite.
+
+        Raises:
+            SynapseError if no appropriate user is found.
+        """
+        power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
+        invite_level = 0
+        users_default_level = 0
+        if power_level_event_id:
+            power_level_event = await self._store.get_event(power_level_event_id)
+            invite_level = power_level_event.content.get("invite", invite_level)
+            users_default_level = power_level_event.content.get(
+                "users_default", users_default_level
+            )
+            users = power_level_event.content.get("users", {})
+        else:
+            users = {}
+
+        # Find the user with the highest power level.
+        users_in_room = await self._store.get_users_in_room(room_id)
+        # Only interested in local users.
+        local_users_in_room = [
+            u for u in users_in_room if get_domain_from_id(u) == self._server_name
+        ]
+        chosen_user = max(
+            local_users_in_room,
+            key=lambda user: users.get(user, users_default_level),
+            default=None,
+        )
+
+        # Return the chosen user if they can issue invites.
+        user_power_level = users.get(chosen_user, users_default_level)
+        if chosen_user and user_power_level >= invite_level:
+            logger.debug(
+                "Found a user who can issue invites  %s with power level %d >= invite level %d",
+                chosen_user,
+                user_power_level,
+                invite_level,
+            )
+            return chosen_user
+
+        # No user was found.
+        raise SynapseError(
+            400,
+            "Unable to find a user which could issue an invite",
+            Codes.UNABLE_TO_GRANT_JOIN,
+        )
+
     async def check_host_in_room(self, room_id: str, host: str) -> bool:
         with Measure(self._clock, "check_host_in_room"):
             return await self._store.is_host_joined(room_id, host)
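A worked example of the power-level selection in get_user_which_could_invite above, with hypothetical values:

    users = {"@admin:example.com": 100, "@mod:example.com": 50}
    users_default_level = 0
    invite_level = 50
    local_users_in_room = ["@admin:example.com", "@alice:example.com"]

    chosen_user = max(
        local_users_in_room,
        key=lambda user: users.get(user, users_default_level),
        default=None,
    )
    # @admin:example.com wins with power level 100 >= invite level 50;
    # @alice:example.com would fall back to users_default_level (0).
    assert chosen_user == "@admin:example.com"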
@@ -134,6 +199,18 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
+
+            # If this is a remote request, the user might be in an allowed room
+            # that we do not know about.
+            if get_domain_from_id(user_id) != self._server_name:
+                for room_id in allowed_rooms:
+                    if not await self._store.is_host_joined(room_id, self._server_name):
+                        raise SynapseError(
+                            400,
+                            f"Unable to check if {user_id} is in allowed rooms.",
+                            Codes.UNABLE_AUTHORISE_JOIN,
+                        )
+
             raise AuthError(
                 403,
                 "You do not belong to any of the required rooms to join this room.",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5728719909..8197b60b76 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
                 host_list, event, room_version_obj
             )
 
-            origin = ret["origin"]
-            state = ret["state"]
-            auth_chain = ret["auth_chain"]
+            event = ret.event
+            origin = ret.origin
+            state = ret.state
+            auth_chain = ret.auth_chain
             auth_chain.sort(key=lambda e: e.depth)
 
             logger.debug("do_invite_join auth_chain: %s", auth_chain)
@@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
 
         # checking the room version will check that we've actually heard of the room
         # (and return a 404 otherwise)
-        room_version = await self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
@@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
 
         event_content = {"membership": Membership.JOIN}
 
+        # If the current room is using restricted join rules, additional information
+        # may need to be included in the event content in order to efficiently
+        # validate the event.
+        #
+        # Note that this requires the /send_join request to come back to the
+        # same server.
+        if room_version.msc3083_join_rules:
+            state_ids = await self.store.get_current_state_ids(room_id)
+            if await self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
+                # If the user is invited or joined to the room already, then
+                # no additional info is needed.
+                include_auth_user_id = True
+                if prev_member_event_id:
+                    prev_member_event = await self.store.get_event(prev_member_event_id)
+                    include_auth_user_id = prev_member_event.membership not in (
+                        Membership.JOIN,
+                        Membership.INVITE,
+                    )
+
+                if include_auth_user_id:
+                    event_content[
+                        "join_authorised_via_users_server"
+                    ] = await self._event_auth_handler.get_user_which_could_invite(
+                        room_id,
+                        state_ids,
+                    )
+
         builder = self.event_builder_factory.new(
-            room_version,
+            room_version.identifier,
             {
                 "type": EventTypes.Member,
                 "content": event_content,
@@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
             logger.warning("Failed to create join to %s because %s", room_id, e)
             raise
 
+        # Ensure the user can even join the room.
+        await self._check_join_restrictions(context, event)
+
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
         await self._event_auth_handler.check_from_context(
-            room_version, event, context, do_sig_check=False
+            room_version.identifier, event, context, do_sig_check=False
         )
 
         return event
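Concretely, the membership event returned to the joining server for a restricted room then carries the authorising user in its content, roughly like this (user IDs are hypothetical):

    {
        "type": "m.room.member",
        "state_key": "@joiner:remote.example",
        "content": {
            "membership": "join",
            "join_authorised_via_users_server": "@inviter:local.example",
        },
    }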
@@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
     @log_function
     async def on_send_membership_event(
         self, origin: str, event: EventBase
-    ) -> EventContext:
+    ) -> Tuple[EventBase, EventContext]:
         """
         We have received a join/leave/knock event for a room via send_join/leave/knock.
 
@@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
             event: The member event that has been signed by the remote homeserver.
 
         Returns:
-            The context of the event after inserting it into the room graph.
+            The event and context of the event after inserting it into the room graph.
 
         Raises:
             SynapseError if the event is not accepted into the room
@@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
 
         # all looks good, we can persist the event.
         await self._run_push_actions_and_persist_event(event, context)
-        return context
+        return event, context
 
     async def _check_join_restrictions(
         self, context: EventContext, event: EventBase
@@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
         )
 
         # Now check if event pass auth against said current state
-        auth_types = auth_types_for_event(event)
+        auth_types = auth_types_for_event(room_version_obj, event)
         current_state_ids_list = [
             e for k, e in current_state_ids.items() if k in auth_types
         ]
@@ -2714,9 +2748,11 @@ class FederationHandler(BaseHandler):
                             event.event_id,
                             e.event_id,
                         )
-                        context = await self.state_handler.compute_event_context(e)
+                        missing_auth_event_context = (
+                            await self.state_handler.compute_event_context(e)
+                        )
                         await self._auth_and_persist_event(
-                            origin, e, context, auth_events=auth
+                            origin, e, missing_auth_event_context, auth_events=auth
                         )
 
                         if e.event_id in event_auth_events:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5d49640760..e1c544a3c9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -21,6 +21,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.receipts import ReceiptEventSource
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
@@ -134,6 +135,8 @@ class InitialSyncHandler(BaseHandler):
             joined_rooms,
             to_key=int(now_token.receipt_key),
         )
+        if self.hs.config.experimental.msc2285_enabled:
+            receipt = ReceiptEventSource.filter_out_hidden(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -430,7 +433,9 @@ class InitialSyncHandler(BaseHandler):
                 room_id, to_key=now_token.receipt_key
             )
             if not receipts:
-                receipts = []
+                return []
+            if self.hs.config.experimental.msc2285_enabled:
+                receipts = ReceiptEventSource.filter_out_hidden(receipts, user_id)
             return receipts
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 283483fc2c..b9085bbccb 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,9 +14,10 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
+from synapse.api.constants import ReadReceiptEventFields
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
-from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
+from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -137,7 +138,7 @@ class ReceiptsHandler(BaseHandler):
         return True
 
     async def received_client_receipt(
-        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+        self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
@@ -147,23 +148,67 @@ class ReceiptsHandler(BaseHandler):
             receipt_type=receipt_type,
             user_id=user_id,
             event_ids=[event_id],
-            data={"ts": int(self.clock.time_msec())},
+            data={"ts": int(self.clock.time_msec()), "hidden": hidden},
         )
 
         is_new = await self._handle_new_receipts([receipt])
         if not is_new:
             return
 
-        if self.federation_sender:
+        if self.federation_sender and not (
+            self.hs.config.experimental.msc2285_enabled and hidden
+        ):
             await self.federation_sender.send_read_receipt(receipt)
 
 
 class ReceiptEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
+        self.config = hs.config
+
+    @staticmethod
+    def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
+        visible_events = []
+
+        # filter out hidden receipts the user shouldn't see
+        for event in events:
+            content = event.get("content", {})
+            new_event = event.copy()
+            new_event["content"] = {}
+
+            for event_id in content.keys():
+                event_content = content.get(event_id, {})
+                m_read = event_content.get("m.read", {})
+
+                # If m_read is missing, copy over the original event_content as there is nothing to process here
+                if not m_read:
+                    new_event["content"][event_id] = event_content.copy()
+                    continue
+
+                new_users = {}
+                for rr_user_id, user_rr in m_read.items():
+                    hidden = user_rr.get("hidden", None)
+                    if hidden is not True or rr_user_id == user_id:
+                        new_users[rr_user_id] = user_rr.copy()
+                        # If hidden has a value, replace it with the correctly prefixed key
+                        if hidden is not None:
+                            new_users[rr_user_id].pop("hidden")
+                            new_users[rr_user_id][
+                                ReadReceiptEventFields.MSC2285_HIDDEN
+                            ] = hidden
+
+                # Only keep the receipt if any users remain
+                if len(new_users.keys()) > 0:
+                    new_event["content"][event_id] = {"m.read": new_users}
+
+            # Only keep the event if any receipts remain
+            if len(new_event["content"].keys()) > 0:
+                visible_events.append(new_event)
+
+        return visible_events
 
     async def get_new_events(
-        self, from_key: int, room_ids: List[str], **kwargs
+        self, from_key: int, room_ids: List[str], user: UserID, **kwargs
     ) -> Tuple[List[JsonDict], int]:
         from_key = int(from_key)
         to_key = self.get_current_key()
@@ -175,6 +220,9 @@ class ReceiptEventSource:
             room_ids, from_key=from_key, to_key=to_key
         )
 
+        if self.config.experimental.msc2285_enabled:
+            events = ReceiptEventSource.filter_out_hidden(events, user.to_string())
+
         return (events, to_key)
 
     async def get_new_events_as(
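A worked example of filter_out_hidden above (IDs hypothetical): @bob's hidden receipt is stripped when @alice asks, but would be kept for @bob himself:

    event = {
        "type": "m.receipt",
        "room_id": "!room:example.com",
        "content": {
            "$event1:example.com": {
                "m.read": {
                    "@alice:example.com": {"ts": 1436451550},
                    "@bob:example.com": {"ts": 1436451550, "hidden": True},
                }
            }
        },
    }
    filtered = ReceiptEventSource.filter_out_hidden([event], "@alice:example.com")
    read = filtered[0]["content"]["$event1:example.com"]["m.read"]
    assert "@bob:example.com" not in read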
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 370561e549..b33fe09f77 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -951,6 +951,7 @@ class RoomCreationHandler(BaseHandler):
                 "kick": 50,
                 "redact": 50,
                 "invite": 50,
+                "historical": 100,
             }
 
             if config["original_invitees_have_ops"]:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1192591609..65ad3efa6a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
 
 from synapse import types
 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
@@ -28,6 +28,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import (
@@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         if event.membership == Membership.JOIN:
             newly_joined = True
-            prev_member_event = None
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
 
-            # Check if the member should be allowed access via membership in a space.
-            await self.event_auth_handler.check_restricted_join_rules(
-                prev_state_ids, event.room_version, user_id, prev_member_event
-            )
-
             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
             if newly_joined and ratelimit:
@@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     # so don't really fit into the general auth process.
                     raise AuthError(403, "Guest access not allowed")
 
-            if not is_host_in_room:
+            # Check if a remote join should be performed.
+            remote_join, remote_room_hosts = await self._should_perform_remote_join(
+                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+            )
+            if remote_join:
                 if ratelimit:
                     time_now_s = self.clock.time()
                     (
@@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             outlier=outlier,
         )
 
+    async def _should_perform_remote_join(
+        self,
+        user_id: str,
+        room_id: str,
+        remote_room_hosts: List[str],
+        content: JsonDict,
+        is_host_in_room: bool,
+    ) -> Tuple[bool, List[str]]:
+        """
+        Check whether the server should do a remote join (as opposed to a local
+        join) for a user.
+
+        Generally a remote join is used if:
+
+        * The server is not yet in the room.
+        * The server is in the room, the room has restricted join rules, the user
+          is not joined or invited to the room, and the server does not have
+          another user who is capable of issuing invites.
+
+        Args:
+            user_id: The user joining the room.
+            room_id: The room being joined.
+            remote_room_hosts: A list of remote room hosts.
+            content: The content to use as the event body of the join. This may
+                be modified.
+            is_host_in_room: True if the host is in the room.
+
+        Returns:
+            A tuple of:
+                True if a remote join should be performed. False if the join can be
+                done locally.
+
+                A list of remote room hosts to use. This is an empty list if a
+                local join is to be done.
+        """
+        # If the host isn't in the room, pass through the prospective hosts.
+        if not is_host_in_room:
+            return True, remote_room_hosts
+
+        # If the host is in the room, but not one of the authorised hosts
+        # for restricted join rules, a remote join must be used.
+        room_version = await self.store.get_room_version(room_id)
+        current_state_ids = await self.store.get_current_state_ids(room_id)
+
+        # If restricted join rules are not being used, a local join can always
+        # be used.
+        if not await self.event_auth_handler.has_restricted_join_rules(
+            current_state_ids, room_version
+        ):
+            return False, []
+
+        # If the user is invited to the room or already joined, the join
+        # event can always be issued locally.
+        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+            if prev_member_event.membership in (
+                Membership.JOIN,
+                Membership.INVITE,
+            ):
+                return False, []
+
+        # If the local host has a user who can issue invites, then a local
+        # join can be done.
+        #
+        # If not, generate a new list of remote hosts based on which
+        # can issue invites.
+        event_map = await self.store.get_events(current_state_ids.values())
+        current_state = {
+            state_key: event_map[event_id]
+            for state_key, event_id in current_state_ids.items()
+        }
+        allowed_servers = get_servers_from_users(
+            get_users_which_can_issue_invite(current_state)
+        )
+
+        # If the local server is not one of the allowed servers, then a remote
+        # join must be done. Return the list of prospective servers whose users
+        # can issue invites.
+        if self.hs.hostname not in allowed_servers:
+            return True, list(allowed_servers)
+
+        # Ensure the member should be allowed access via membership in an allowed room.
+        await self.event_auth_handler.check_restricted_join_rules(
+            current_state_ids, room_version, user_id, prev_member_event
+        )
+
+        # If this is going to be a local join, additional information must
+        # be included in the event content in order to efficiently validate
+        # the event.
+        content[
+            "join_authorised_via_users_server"
+        ] = await self.event_auth_handler.get_user_which_could_invite(
+            room_id,
+            current_state_ids,
+        )
+
+        return False, []
+
     async def transfer_room_state_on_room_upgrade(
         self, old_room_id: str, room_id: str
     ) -> None:
@@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         if membership:
             await self.store.forget(user_id, room_id)
+
+
+def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
+    """
+    Return the list of users who can issue invites.
+
+    This is done by examining the joined users and comparing their power levels
+    to the power level necessary to issue an invite.
+
+    Args:
+        auth_events: state in force at this point in the room
+
+    Returns:
+        The users who can issue invites.
+    """
+    invite_level = get_named_level(auth_events, "invite", 0)
+    users_default_level = get_named_level(auth_events, "users_default", 0)
+    power_level_event = get_power_level_event(auth_events)
+
+    # Custom power-levels for users.
+    if power_level_event:
+        users = power_level_event.content.get("users", {})
+    else:
+        users = {}
+
+    result = []
+
+    # Check which members are able to invite by ensuring they're joined and have
+    # the necessary power level.
+    for (event_type, state_key), event in auth_events.items():
+        if event_type != EventTypes.Member:
+            continue
+
+        if event.membership != Membership.JOIN:
+            continue
+
+        # Check if the user has a custom power level.
+        if users.get(state_key, users_default_level) >= invite_level:
+            result.append(state_key)
+
+    return result
+
+
+def get_servers_from_users(users: List[str]) -> Set[str]:
+    """
+    Resolve a list of users into their servers.
+
+    Args:
+        users: A list of users.
+
+    Returns:
+        A set of servers.
+    """
+    servers = set()
+    for user in users:
+        try:
+            servers.add(get_domain_from_id(user))
+        except SynapseError:
+            pass
+    return servers
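For instance (hypothetical MXIDs), malformed IDs are skipped rather than failing the whole resolution:

    servers = get_servers_from_users(
        ["@alice:one.example", "@bob:two.example", "not-a-user-id"]
    )
    assert servers == {"one.example", "two.example"}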
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2ac76b15c2..c2ea51ee16 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -847,7 +847,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
 
 def read_body_with_max_size(
     response: IResponse, stream: ByteWriteable, max_size: Optional[int]
-) -> defer.Deferred:
+) -> "defer.Deferred[int]":
     """
     Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
 
@@ -862,7 +862,7 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
-    d = defer.Deferred()
+    d: "defer.Deferred[int]" = defer.Deferred()
 
     # If the Content-Length header gives a size larger than the maximum allowed
     # size, do not bother downloading the body.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 950770201a..c16b7f10e6 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -27,7 +27,7 @@ from twisted.internet.interfaces import (
 )
 from twisted.web.client import URI, Agent, HTTPConnectionPool
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse
 
 from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http.client import BlacklistingAgentWrapper
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
+    ) -> Generator[defer.Deferred, Any, IResponse]:
         """
         Args:
             method: HTTP method: GET/POST/etc
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f7193e60bd..19e987f118 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -14,21 +14,32 @@
 import base64
 import logging
 import re
-from typing import Optional, Tuple
-from urllib.request import getproxies_environment, proxy_bypass_environment
+from typing import Any, Dict, Optional, Tuple
+from urllib.parse import urlparse
+from urllib.request import (  # type: ignore[attr-defined]
+    getproxies_environment,
+    proxy_bypass_environment,
+)
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.python.failure import Failure
-from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.client import (
+    URI,
+    BrowserLikePolicyForHTTPS,
+    HTTPConnectionPool,
+    _AgentBase,
+)
 from twisted.web.error import SchemeNotSupported
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
 
 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
 
@@ -63,35 +74,38 @@ class ProxyAgent(_AgentBase):
                        reactor might have some blacklisting applied (e.g. for DNS queries),
                        but we need unblocked access to the proxy.
 
-        contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+        contextFactory: A factory for TLS contexts, to control the
             verification parameters of OpenSSL.  The default is to use a
             `BrowserLikePolicyForHTTPS`, so unless you have special
             requirements you can leave this as-is.
 
-        connectTimeout (Optional[float]): The amount of time that this Agent will wait
+        connectTimeout: The amount of time that this Agent will wait
             for the peer to accept a connection, in seconds. If 'None',
             HostnameEndpoint's default (30s) will be used.
-
             This is used for connections to both proxies and destination servers.
 
-        bindAddress (bytes): The local address for client sockets to bind to.
+        bindAddress: The local address for client sockets to bind to.
 
-        pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+        pool: connection pool to be used. If None, a
             non-persistent pool instance will be created.
 
-        use_proxy (bool): Whether proxy settings should be discovered and used
+        use_proxy: Whether proxy settings should be discovered and used
             from conventional environment variables.
+
+    Raises:
+        ValueError if use_proxy is set and the environment variables
+            contain an invalid proxy specification.
     """
 
     def __init__(
         self,
-        reactor,
-        proxy_reactor=None,
+        reactor: IReactorCore,
+        proxy_reactor: Optional[ISynapseReactor] = None,
         contextFactory: Optional[IPolicyForHTTPS] = None,
-        connectTimeout=None,
-        bindAddress=None,
-        pool=None,
-        use_proxy=False,
+        connectTimeout: Optional[float] = None,
+        bindAddress: Optional[bytes] = None,
+        pool: Optional[HTTPConnectionPool] = None,
+        use_proxy: bool = False,
     ):
         contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
 
@@ -102,7 +116,7 @@ class ProxyAgent(_AgentBase):
         else:
             self.proxy_reactor = proxy_reactor
 
-        self._endpoint_kwargs = {}
+        self._endpoint_kwargs: Dict[str, Any] = {}
         if connectTimeout is not None:
             self._endpoint_kwargs["timeout"] = connectTimeout
         if bindAddress is not None:
@@ -117,16 +131,12 @@ class ProxyAgent(_AgentBase):
             https_proxy = proxies["https"].encode() if "https" in proxies else None
             no_proxy = proxies["no"] if "no" in proxies else None
 
-        # Parse credentials from http and https proxy connection string if present
-        self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
-        self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
-
-        self.http_proxy_endpoint = _http_proxy_endpoint(
-            http_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint(
+            http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
-        self.https_proxy_endpoint = _http_proxy_endpoint(
-            https_proxy, self.proxy_reactor, **self._endpoint_kwargs
+        self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint(
+            https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
         )
 
         self.no_proxy = no_proxy
@@ -134,7 +144,13 @@ class ProxyAgent(_AgentBase):
         self._policy_for_https = contextFactory
         self._reactor = reactor
 
-    def request(self, method, uri, headers=None, bodyProducer=None):
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> defer.Deferred:
         """
         Issue a request to the server indicated by the given uri.
 
@@ -146,16 +162,15 @@ class ProxyAgent(_AgentBase):
         See also: twisted.web.iweb.IAgent.request
 
         Args:
-            method (bytes): The request method to use, such as `GET`, `POST`, etc
+            method: The request method to use, such as `GET`, `POST`, etc
 
-            uri (bytes): The location of the resource to request.
+            uri: The location of the resource to request.
 
-            headers (Headers|None): Extra headers to send with the request
+            headers: Extra headers to send with the request
 
-            bodyProducer (IBodyProducer|None): An object which can generate bytes to
-                make up the body of this request (for example, the properly encoded
-                contents of a file for a file upload). Or, None if the request is to
-                have no body.
+            bodyProducer: An object which can generate bytes to make up the body of
+                this request (for example, the properly encoded contents of a file for
+                a file upload). Or, None if the request is to have no body.
 
         Returns:
             Deferred[IResponse]: completes when the header of the response has
@@ -253,70 +268,89 @@ class ProxyAgent(_AgentBase):
         )
 
 
-def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
+def _http_proxy_endpoint(
+    proxy: Optional[bytes],
+    reactor: IReactorCore,
+    tls_options_factory: IPolicyForHTTPS,
+    **kwargs,
+) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
     Args:
-        proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
-            Note that compared to other apps, this function currently lacks support
-            for specifying a protocol schema (i.e. protocol://...).
+        proxy: the proxy setting in the form: [scheme://][<username>:<password>@]<host>[:<port>]
+            This currently supports http:// and https:// proxies.
+            A hostname without a scheme is assumed to be http.
 
         reactor: reactor to be used to connect to the proxy
 
+        tls_options_factory: the TLS options to use when connecting through an https proxy
+
         kwargs: other args to be passed to HostnameEndpoint
 
     Returns:
-        interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
-            or None
+        a tuple of
+            the endpoint to use to connect to the proxy, or None
+            the ProxyCredentials to use, or None if no credentials were found
+
+    Raises:
+        ValueError if proxy has no hostname or an unsupported scheme.
     """
     if proxy is None:
-        return None
+        return None, None
 
-    # Parse the connection string
-    host, port = parse_host_port(proxy, default_port=1080)
-    return HostnameEndpoint(reactor, host, port, **kwargs)
+    # Note: urlsplit/urlparse cannot be used here as that does not work (for Python
+    # 3.9+) on scheme-less proxies, e.g. host:port.
+    scheme, host, port, credentials = parse_proxy(proxy)
 
+    proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
 
-def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
-    """
-    Parses the username and password from a proxy declaration e.g
-    username:password@hostname:port.
+    if scheme == b"https":
+        tls_options = tls_options_factory.creatorForNetloc(host, port)
+        proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
 
-    Args:
-        proxy: The proxy connection string.
+    return proxy_endpoint, credentials
 
-    Returns
-        An instance of ProxyCredentials and the proxy connection string with any credentials
-        stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
-        ProxyCredentials instance is replaced with None.
-    """
-    if proxy and b"@" in proxy:
-        # We use rsplit here as the password could contain an @ character
-        credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
-        return ProxyCredentials(credentials), proxy_without_credentials
 
-    return None, proxy
+def parse_proxy(
+    proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080
+) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
+    """
+    Parse a proxy connection string.
 
+    Given an HTTP proxy URL, breaks it down into components and checks that it
+    has a hostname (otherwise it is not useful to us when trying to find a
+    proxy) and asserts that the URL has a scheme we support.
 
-def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
-    """
-    Parse the hostname and port from a proxy connection byte string.
 
     Args:
-        hostport: The proxy connection string. Must be in the form 'host[:port]'.
-        default_port: The default port to return if one is not found in `hostport`.
+        proxy: The proxy connection string. Must be in the form '[scheme://][<username>:<password>@]host[:port]'.
+        default_scheme: The default scheme to return if one is not found in `proxy`. Defaults to http.
+        default_port: The default port to return if one is not found in `proxy`. Defaults to 1080.
 
     Returns:
-        A tuple containing the hostname and port. Uses `default_port` if one was not found.
+        A tuple containing the scheme, hostname, port and ProxyCredentials.
+            If no credentials were found, the ProxyCredentials instance is replaced with None.
+
+    Raises:
+        ValueError if proxy has no hostname or an unsupported scheme.
     """
-    if b":" in hostport:
-        host, port = hostport.rsplit(b":", 1)
-        try:
-            port = int(port)
-            return host, port
-        except ValueError:
-            # the thing after the : wasn't a valid port; presumably this is an
-            # IPv6 address.
-            pass
+    # First check if we have a scheme present
+    # Note: urlsplit/urlparse cannot be used (for Python 3.9+) on scheme-less proxies, e.g. host:port.
+    if b"://" not in proxy:
+        proxy = b"".join([default_scheme, b"://", proxy])
+
+    url = urlparse(proxy)
+
+    if not url.hostname:
+        raise ValueError("Proxy URL did not contain a hostname! Please specify one.")
+
+    if url.scheme not in (b"http", b"https"):
+        raise ValueError(
+            f"Unknown proxy scheme {url.scheme!s}; only 'http' and 'https' is supported."
+        )
+
+    credentials = None
+    if url.username and url.password:
+        credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
 
-    return hostport, default_port
+    return url.scheme, url.hostname, url.port or default_port, credentials
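
To make the new parsing rules concrete, here is a minimal sketch (not part of
the change itself) of what parse_proxy returns; the hosts, port and
credentials below are illustrative:

    from synapse.http.proxyagent import parse_proxy

    # A bare host picks up the defaults: http scheme, port 1080, no credentials.
    assert parse_proxy(b"squid.example.com") == (b"http", b"squid.example.com", 1080, None)

    # Scheme, port and credentials are all extracted when present.
    scheme, host, port, creds = parse_proxy(b"https://user:pass@squid.example.com:3128")
    assert (scheme, host, port) == (b"https", b"squid.example.com", 3128)
    assert creds is not None  # wraps b"user:pass"

    # A proxy with no hostname, or a scheme other than http/https, raises ValueError.
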
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 04560fb589..732a1e6aeb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,47 +14,86 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 import logging
-from typing import Dict, Iterable, List, Optional, overload
+from typing import Iterable, List, Mapping, Optional, Sequence, overload
 
 from typing_extensions import Literal
 
 from twisted.web.server import Request
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.types import JsonDict
 from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
-def parse_integer(request, name, default=None, required=False):
+@overload
+def parse_integer(request: Request, name: str, default: int) -> int:
+    ...
+
+
+@overload
+def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int:
+    ...
+
+
+@overload
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
+    ...
+
+
+def parse_integer(
+    request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
     """Parse an integer parameter from the request string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (int|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        int|None: An int value or the default.
+        An int value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
-    return parse_integer_from_args(request.args, name, default, required)
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_integer_from_args(args, name, default, required)
+
 
+def parse_integer_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[int] = None,
+    required: bool = False,
+) -> Optional[int]:
+    """Parse an integer parameter from the request string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
-def parse_integer_from_args(args, name, default=None, required=False):
+    Returns:
+        An int value or the default.
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
+    name_bytes = name.encode("ascii")
 
-    if name in args:
+    if name_bytes in args:
         try:
-            return int(args[name][0])
+            return int(args[name_bytes][0])
         except Exception:
             message = "Query parameter %r must be an integer" % (name,)
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
@@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False):
             return default
 
 
-def parse_boolean(request, name, default=None, required=False):
+@overload
+def parse_boolean(request: Request, name: str, default: bool) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool:
+    ...
+
+
+@overload
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
+    ...
+
+
+def parse_boolean(
+    request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
     """Parse a boolean parameter from the request query string
 
     Args:
         request: the twisted HTTP request.
-        name (bytes/unicode): the name of the query parameter.
-        default (bool|None): value to use if the parameter is absent, defaults
-            to None.
-        required (bool): whether to raise a 400 SynapseError if the
-            parameter is absent, defaults to False.
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
 
     Returns:
-        bool|None: A bool value or the default.
+        A bool value or the default.
 
     Raises:
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not one of "true" or "false".
     """
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
+    return parse_boolean_from_args(args, name, default, required)
+
+
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: bool,
+) -> bool:
+    ...
+
 
-    return parse_boolean_from_args(request.args, name, default, required)
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    *,
+    required: Literal[True],
+) -> bool:
+    ...
 
 
-def parse_boolean_from_args(args, name, default=None, required=False):
+@overload
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    ...
 
-    if not isinstance(name, bytes):
-        name = name.encode("ascii")
 
-    if name in args:
+def parse_boolean_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: Optional[bool] = None,
+    required: bool = False,
+) -> Optional[bool]:
+    """Parse a boolean parameter from the request query string
+
+    Args:
+        args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+        name: the name of the query parameter.
+        default: value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the parameter is absent,
+            defaults to False.
+
+    Returns:
+        A bool value or the default.
+
+    Raises:
+        SynapseError: if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+    name_bytes = name.encode("ascii")
+
+    if name_bytes in args:
         try:
-            return {b"true": True, b"false": False}[args[name][0]]
+            return {b"true": True, b"false": False}[args[name_bytes][0]]
         except Exception:
             message = (
                 "Boolean query parameter %r must be one of ['true', 'false']"
@@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
 ) -> Optional[bytes]:
@@ -120,7 +225,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Literal[None] = None,
     *,
@@ -131,7 +236,7 @@ def parse_bytes_from_args(
 
 @overload
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -140,7 +245,7 @@ def parse_bytes_from_args(
 
 
 def parse_bytes_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[bytes] = None,
     required: bool = False,
@@ -172,6 +277,42 @@ def parse_bytes_from_args(
     return default
 
 
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    default: str,
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: Literal[True],
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[str]:
+    ...
+
+
 def parse_string(
     request: Request,
     name: str,
@@ -179,7 +320,7 @@ def parse_string(
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
-):
+) -> Optional[str]:
     """
     Parse a string parameter from the request query string.
 
@@ -205,7 +346,7 @@ def parse_string(
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    args: Dict[bytes, List[bytes]] = request.args  # type: ignore
+    args: Mapping[bytes, Sequence[bytes]] = request.args  # type: ignore
     return parse_string_from_args(
         args,
         name,
@@ -239,9 +380,8 @@ def _parse_string_value(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -251,9 +391,20 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
+    name: str,
+    default: List[str],
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> List[str]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
-    default: Optional[List[str]] = None,
     *,
     required: Literal[True],
     allowed_values: Optional[Iterable[str]] = None,
@@ -264,7 +415,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     *,
@@ -276,7 +427,7 @@ def parse_strings_from_args(
 
 
 def parse_strings_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[List[str]] = None,
     required: bool = False,
@@ -325,7 +476,7 @@ def parse_strings_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -337,7 +488,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     *,
@@ -350,7 +501,7 @@ def parse_string_from_args(
 
 @overload
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -361,7 +512,7 @@ def parse_string_from_args(
 
 
 def parse_string_from_args(
-    args: Dict[bytes, List[bytes]],
+    args: Mapping[bytes, Sequence[bytes]],
     name: str,
     default: Optional[str] = None,
     required: bool = False,
@@ -409,13 +560,14 @@ def parse_string_from_args(
     return strings[0]
 
 
-def parse_json_value_from_request(request, allow_empty_body=False):
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into None
+        allow_empty_body: if True, an empty body will be accepted and turned into None
 
     Returns:
         The JSON value.
@@ -424,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         SynapseError if the request body couldn't be decoded as JSON.
     """
     try:
-        content_bytes = request.content.read()
+        content_bytes = request.content.read()  # type: ignore
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
@@ -440,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     return content
 
 
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> JsonDict:
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
-        allow_empty_body (bool): if True, an empty body will be accepted and
-            turned into an empty dict.
+        allow_empty_body: if True, an empty body will be accepted and turned into
+            an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
@@ -457,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
     if allow_empty_body and content is None:
         return {}
 
-    if type(content) != dict:
+    if not isinstance(content, dict):
         message = "Content must be a JSON object."
         raise SynapseError(400, message, errcode=Codes.BAD_JSON)
 
     return content
 
 
-def assert_params_in_dict(body, required):
+def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
     absent = []
     for k in required:
         if k not in body:
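
As a hedged illustration of the now fully-typed helpers, the *_from_args
variants can be exercised directly against a plain mapping (all values below
are made up):

    from synapse.http.servlet import parse_boolean_from_args, parse_integer_from_args

    args = {b"limit": [b"10"], b"full_state": [b"false"]}

    assert parse_integer_from_args(args, "limit", default=5) == 10
    assert parse_boolean_from_args(args, "full_state", default=True) is False

    # Absent parameters fall back to the default; with required=True they
    # raise a 400 SynapseError instead.
    assert parse_integer_from_args(args, "missing", default=3) == 3
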
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 18ac507802..02e5ddd2ef 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works.
 import inspect
 import logging
 import threading
-import types
+import typing
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
 
@@ -745,7 +745,7 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
         # by synchronous exceptions, so let's turn them into Failures.
         return defer.fail()
 
-    if isinstance(res, types.CoroutineType):
+    if isinstance(res, typing.Coroutine):
         res = defer.ensureDeferred(res)
 
     # At this point we should have a Deferred, if not then f was a synchronous
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
new file mode 100644
index 0000000000..a6c212f300
--- /dev/null
+++ b/synapse/logging/handlers.py
@@ -0,0 +1,88 @@
+import logging
+import time
+from logging import Handler, LogRecord
+from logging.handlers import MemoryHandler
+from threading import Thread
+from typing import Optional
+
+from twisted.internet.interfaces import IReactorCore
+
+
+class PeriodicallyFlushingMemoryHandler(MemoryHandler):
+    """
+    This is a subclass of MemoryHandler that additionally spawns a background
+    thread to periodically flush the buffer.
+
+    This prevents messages from being buffered for too long.
+
+    Additionally, all messages will be immediately flushed if the reactor has
+    not yet been started.
+    """
+
+    def __init__(
+        self,
+        capacity: int,
+        flushLevel: int = logging.ERROR,
+        target: Optional[Handler] = None,
+        flushOnClose: bool = True,
+        period: float = 5.0,
+        reactor: Optional[IReactorCore] = None,
+    ) -> None:
+        """
+        period: the period between automatic flushes
+
+        reactor: if specified, a custom reactor to use. If not specified,
+            defaults to the globally-installed reactor.
+            Log entries will be flushed immediately until this reactor has
+            started.
+        """
+        super().__init__(capacity, flushLevel, target, flushOnClose)
+
+        self._flush_period: float = period
+        self._active: bool = True
+        self._reactor_started = False
+
+        self._flushing_thread: Thread = Thread(
+            name="PeriodicallyFlushingMemoryHandler flushing thread",
+            target=self._flush_periodically,
+        )
+        self._flushing_thread.start()
+
+        def on_reactor_running():
+            self._reactor_started = True
+
+        reactor_to_use: IReactorCore
+        if reactor is None:
+            from twisted.internet import reactor as global_reactor
+
+            reactor_to_use = global_reactor  # type: ignore[assignment]
+        else:
+            reactor_to_use = reactor
+
+        # call our hook when the reactor starts up
+        reactor_to_use.callWhenRunning(on_reactor_running)
+
+    def shouldFlush(self, record: LogRecord) -> bool:
+        """
+        Before reactor start-up, log everything immediately.
+        Otherwise, fall back to original behaviour of waiting for the buffer to fill.
+        """
+
+        if self._reactor_started:
+            return super().shouldFlush(record)
+        else:
+            return True
+
+    def _flush_periodically(self):
+        """
+        Whilst this handler is active, flush the handler periodically.
+        """
+
+        while self._active:
+            # flush is thread-safe; it acquires and releases the lock internally
+            self.flush()
+            time.sleep(self._flush_period)
+
+    def close(self) -> None:
+        self._active = False
+        super().close()
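
A sketch of wiring the new handler up via logging.config.dictConfig; the
handler class path comes from this commit, while the handler names and log
file path are illustrative. dictConfig resolves the "target" key to another
handler because the class subclasses MemoryHandler:

    import logging
    import logging.config

    logging.config.dictConfig(
        {
            "version": 1,
            "handlers": {
                "file": {"class": "logging.FileHandler", "filename": "homeserver.log"},
                "buffer": {
                    "class": "synapse.logging.handlers.PeriodicallyFlushingMemoryHandler",
                    "target": "file",
                    "capacity": 10,
                    "flushLevel": logging.ERROR,  # errors still flush immediately
                    "period": 5.0,  # flush at least every 5 seconds
                },
            },
            "root": {"level": "INFO", "handlers": ["buffer"]},
        }
    )
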
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 1259fc2d90..473812b8e2 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -484,7 +484,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
+    ) -> Generator[defer.Deferred, Any, Iterable[EventBase]]:
         """Gets current state events for the given room.
 
         (This is exposed for compatibility with the old SpamCheckerApi. We should
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c5fbebc17d..bbe337949a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -111,8 +111,9 @@ class _NotifierUserStream:
         self.last_notified_token = current_token
         self.last_notified_ms = time_now_ms
 
-        with PreserveLoggingContext():
-            self.notify_deferred = ObservableDeferred(defer.Deferred())
+        self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
+            defer.Deferred()
+        )
 
     def notify(
         self,
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 7be5fe1e9b..941fb238b7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
 import bleach
 import jinja2
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomTypes
 from synapse.api.errors import StoreError
 from synapse.config.emailconfig import EmailSubjectConfig
 from synapse.events import EventBase
@@ -600,6 +600,22 @@ class Mailer:
                     "app": self.app_name,
                 }
 
+            # If the room is a space, it gets a slightly different topic.
+            create_event_id = room_state_ids.get(("m.room.create", ""))
+            if create_event_id:
+                create_event = await self.store.get_event(
+                    create_event_id, allow_none=True
+                )
+                if (
+                    create_event
+                    and create_event.content.get("room_type") == RoomTypes.SPACE
+                ):
+                    return self.email_subjects.invite_from_person_to_space % {
+                        "person": inviter_name,
+                        "space": room_name,
+                        "app": self.app_name,
+                    }
+
             return self.email_subjects.invite_from_person_to_room % {
                 "person": inviter_name,
                 "room": room_name,
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 9d4859798b..3fd2811713 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -285,7 +285,7 @@ class ReplicationDataHandler:
 
         # Create a new deferred that times out after N seconds, as we don't want
         # to wedge here forever.
-        deferred = Deferred()
+        deferred: "Deferred[None]" = Deferred()
         deferred = timeout_deferred(
             deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
         )
@@ -393,6 +393,11 @@ class FederationSenderHandler:
             # we only want to send on receipts for our own users
             if not self._is_mine_id(receipt.user_id):
                 continue
+            if (
+                receipt.data.get("hidden", False)
+                and self._hs.config.experimental.msc2285_enabled
+            ):
+                continue
             receipt_info = ReadReceipt(
                 receipt.room_id,
                 receipt.receipt_type,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 589e47fa47..eef76ab18a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -62,6 +62,7 @@ class UsersRestServletV2(RestServlet):
     The parameter `name` can be used to filter by user id or display name.
     The parameter `guests` can be used to exclude guest users.
     The parameter `deactivated` can be used to include deactivated users.
+    The parameter `order_by` can be used to order the results.
     """
 
     def __init__(self, hs: "HomeServer"):
@@ -90,8 +91,8 @@ class UsersRestServletV2(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        user_id = parse_string(request, "user_id", default=None)
-        name = parse_string(request, "name", default=None)
+        user_id = parse_string(request, "user_id")
+        name = parse_string(request, "name")
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
@@ -108,6 +109,7 @@ class UsersRestServletV2(RestServlet):
                 UserSortOrder.USER_TYPE.value,
                 UserSortOrder.AVATAR_URL.value,
                 UserSortOrder.SHADOW_BANNED.value,
+                UserSortOrder.CREATION_TS.value,
             ),
         )
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 31a1193cd3..502a917588 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         assert_params_in_dict(body, ["state_events_at_start", "events"])
 
         prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
-        chunk_id_from_query = parse_string(request, "chunk_id", default=None)
+        chunk_id_from_query = parse_string(request, "chunk_id")
 
         if prev_events_from_query is None:
             raise SynapseError(
@@ -504,7 +504,6 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
 
         events_to_create = body["events"]
 
-        prev_event_ids = prev_events_from_query
         inherited_depth = await self._inherit_depth_from_prev_ids(
             prev_events_from_query
         )
@@ -516,6 +515,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         chunk_id_to_connect_to = chunk_id_from_query
         base_insertion_event = None
         if chunk_id_from_query:
+            # All but the first base insertion event should point at a fake
+            # event, which causes the HS to ask for the state at the start of
+            # the chunk later.
+            prev_event_ids = [fake_prev_event_id]
             # TODO: Verify the chunk_id_from_query corresponds to an insertion event
             pass
         # Otherwise, create an insertion event to act as a starting point.
@@ -526,6 +529,8 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         # an insertion event), in which case we just create a new insertion event
         # that can then get pointed to by a "marker" event later.
         else:
+            prev_event_ids = prev_events_from_query
+
             base_insertion_event_dict = self._create_insertion_event_dict(
                 sender=requester.user.to_string(),
                 room_id=room_id,
@@ -553,9 +558,18 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             ]
 
         # Connect this current chunk to the insertion event from the previous chunk
-        last_event_in_chunk["content"][
-            EventContentFields.MSC2716_CHUNK_ID
-        ] = chunk_id_to_connect_to
+        chunk_event = {
+            "type": EventTypes.MSC2716_CHUNK,
+            "sender": requester.user.to_string(),
+            "room_id": room_id,
+            "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to},
+            # Since the chunk event is put at the end of the chunk,
+            # where the newest-in-time event is, copy the origin_server_ts from
+            # the last event we're inserting
+            "origin_server_ts": last_event_in_chunk["origin_server_ts"],
+        }
+        # Add the chunk event to the end of the chunk (newest-in-time)
+        events_to_create.append(chunk_event)
 
         # Add an "insertion" event to the start of each chunk (next to the oldest-in-time
         # event in the chunk) so the next chunk can be connected to this one.
@@ -567,7 +581,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
             # the first event we're inserting
             origin_server_ts=events_to_create[0]["origin_server_ts"],
         )
-        # Prepend the insertion event to the start of the chunk
+        # Prepend the insertion event to the start of the chunk (oldest-in-time)
         events_to_create = [insertion_event] + events_to_create
 
         event_ids = []
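
A sketch of the chunk layout this produces, oldest-in-time first. The event
types match the constants added in this commit; the content field name and
the ids are illustrative:

    insertion_event = {"type": "org.matrix.msc2716.insertion"}
    historical_events = [{"type": "m.room.message"}, {"type": "m.room.message"}]
    chunk_event = {
        "type": "org.matrix.msc2716.chunk",
        # points at the previous chunk's insertion event
        "content": {"org.matrix.msc2716.chunk_id": "previous-chunk-id"},
    }
    events_to_create = [insertion_event, *historical_events, chunk_event]
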
@@ -726,7 +740,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         self.auth = hs.get_auth()
 
     async def on_GET(self, request):
-        server = parse_string(request, "server", default=None)
+        server = parse_string(request, "server")
 
         try:
             await self.auth.get_user_by_req(request, allow_guest=True)
@@ -745,8 +759,8 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             if server:
                 raise e
 
-        limit = parse_integer(request, "limit", 0)
-        since_token = parse_string(request, "since", None)
+        limit: Optional[int] = parse_integer(request, "limit", 0)
+        since_token = parse_string(request, "since")
 
         if limit == 0:
             # zero is a special value which corresponds to no limit.
@@ -780,7 +794,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
     async def on_POST(self, request):
         await self.auth.get_user_by_req(request, allow_guest=True)
 
-        server = parse_string(request, "server", default=None)
+        server = parse_string(request, "server")
         content = parse_json_object_from_request(request)
 
         limit: Optional[int] = int(content.get("limit", 100))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 085561d3e9..fb5ad2906e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -884,7 +884,14 @@ class WhoamiRestServlet(RestServlet):
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request)
 
-        return 200, {"user_id": requester.user.to_string()}
+        response = {"user_id": requester.user.to_string()}
+
+        # Appservices and similar accounts do not have device IDs
+        # that we can report on, so exclude them for compliance.
+        if requester.device_id is not None:
+            response["device_id"] = requester.device_id
+
+        return 200, response
 
 
 def register_servlets(hs, http_server):
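
Example /account/whoami responses after this change (the ids are made up):

    with_device = {"user_id": "@alice:example.org", "device_id": "JLAFKJWSCS"}
    appservice = {"user_id": "@bridge:example.org"}  # no device_id to report
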
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 6a24021484..88e3aac797 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
 from synapse.http.servlet import RestServlet
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict
@@ -55,6 +55,12 @@ class CapabilitiesRestServlet(RestServlet):
                 "m.change_password": {"enabled": change_password},
             }
         }
+
+        if self.config.experimental.msc3244_enabled:
+            response["capabilities"]["m.room_versions"][
+                "org.matrix.msc3244.room_capabilities"
+            ] = MSC3244_CAPABILITIES
+
         return 200, response
 
 
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 33cf8de186..d0d9d30d40 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet):
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        from_token_string = parse_string(request, "from")
+        from_token_string = parse_string(request, "from", required=True)
         set_tag("from", from_token_string)
 
         # We want to enforce they do pass us one, but we ignore it and return
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 5988fa47e5..027f8b81fa 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -14,6 +14,8 @@
 
 import logging
 
+from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.errors import Codes, SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
@@ -37,14 +39,24 @@ class ReadMarkerRestServlet(RestServlet):
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
-
         read_event_id = body.get("m.read", None)
+        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
+
+        if not isinstance(hidden, bool):
+            raise SynapseError(
+                400,
+                "Param %s must be a boolean, if given"
+                % ReadReceiptEventFields.MSC2285_HIDDEN,
+                Codes.BAD_JSON,
+            )
+
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
                 "m.read",
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
+                hidden=hidden,
             )
 
         read_marker_event_id = body.get("m.fully_read", None)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 8cf4aebdbe..4b98979b47 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -14,8 +14,9 @@
 
 import logging
 
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
 from ._base import client_patterns
 
@@ -42,10 +43,25 @@ class ReceiptRestServlet(RestServlet):
         if receipt_type != "m.read":
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
+        body = parse_json_object_from_request(request)
+        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
+
+        if not isinstance(hidden, bool):
+            raise SynapseError(
+                400,
+                "Param %s must be a boolean, if given"
+                % ReadReceiptEventFields.MSC2285_HIDDEN,
+                Codes.BAD_JSON,
+            )
+
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         await self.receipts_handler.received_client_receipt(
-            room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
+            room_id,
+            receipt_type,
+            user_id=requester.user.to_string(),
+            event_id=event_id,
+            hidden=hidden,
         )
 
         return 200, {}
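
An illustrative request body exercising the validation above, for both the
receipt and read-marker endpoints; the field name is assumed to be the
MSC2285 identifier "org.matrix.msc2285.hidden" behind
ReadReceiptEventFields.MSC2285_HIDDEN:

    receipt_body = {"org.matrix.msc2285.hidden": True}

    read_markers_body = {
        "m.fully_read": "$fully-read-event-id",
        "m.read": "$read-event-id",
        "org.matrix.msc2285.hidden": True,  # must be a bool, else 400 M_BAD_JSON
    }
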
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index c7da6759db..0821cd285f 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet):
         event = await self.event_handler.get_event(requester.user, room_id, parent_id)
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
         if event.internal_metadata.is_redacted():
             # If the event is redacted, return an empty list of relations
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            if from_token:
-                from_token = RelationPaginationToken.from_string(from_token)
+            from_token = None
+            if from_token_str:
+                from_token = RelationPaginationToken.from_string(from_token_str)
 
-            if to_token:
-                to_token = RelationPaginationToken.from_string(to_token)
+            to_token = None
+            if to_token_str:
+                to_token = RelationPaginationToken.from_string(to_token_str)
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
@@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet):
             raise SynapseError(400, "Relation type must be 'annotation'")
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
         if event.internal_metadata.is_redacted():
             # If the event is redacted, return an empty list of relations
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            if from_token:
-                from_token = AggregationPaginationToken.from_string(from_token)
+            from_token = None
+            if from_token_str:
+                from_token = AggregationPaginationToken.from_string(from_token_str)
 
-            if to_token:
-                to_token = AggregationPaginationToken.from_string(to_token)
+            to_token = None
+            if to_token_str:
+                to_token = AggregationPaginationToken.from_string(to_token_str)
 
             pagination_chunk = await self.store.get_aggregation_groups_for_event(
                 event_id=parent_id,
@@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
             raise SynapseError(400, "Relation type must be 'annotation'")
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
-        if from_token:
-            from_token = RelationPaginationToken.from_string(from_token)
+        from_token = None
+        if from_token_str:
+            from_token = RelationPaginationToken.from_string(from_token_str)
 
-        if to_token:
-            to_token = RelationPaginationToken.from_string(to_token)
+        to_token = None
+        if to_token_str:
+            to_token = RelationPaginationToken.from_string(to_token_str)
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 32e8500795..e321668698 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet):
             default="online",
             allowed_values=self.ALLOWED_PRESENCE,
         )
-        filter_id = parse_string(request, "filter", default=None)
+        filter_id = parse_string(request, "filter")
         full_state = parse_boolean(request, "full_state", default=False)
 
         logger.debug(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 4582c274c7..fa2e4e9cba 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -82,6 +82,8 @@ class VersionsRestServlet(RestServlet):
                     "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
                     # Supports the busy presence state described in MSC3026.
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
+                    # Supports receiving hidden read receipts as per MSC2285
+                    "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
                 },
             },
         )
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4282e2b228..11f7320832 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource):
             request (twisted.web.http.Request):
         """
         version = parse_string(request, "v", default=self._default_consent_version)
-        username = parse_string(request, "u", required=False, default="")
+        username = parse_string(request, "u", default="")
         userhmac = None
         has_consented = False
         public_version = username == ""
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index cd2468f9c5..d6d938953e 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -49,6 +49,8 @@ class DownloadResource(DirectServeJsonResource):
             b" media-src 'self';"
             b" object-src 'self';",
         )
+        # Limited non-standard form of CSP for IE11
+        request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
         request.setHeader(
             b"Referrer-Policy",
             b"no-referrer",
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8e7fead3a2..0f051d4041 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,9 +58,11 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_charset_match = re.compile(
+    br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
 _xml_encoding_match = re.compile(
-    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+    br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
 )
 _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
 
@@ -186,15 +188,11 @@ class PreviewUrlResource(DirectServeJsonResource):
         respond_with_json(request, 200, {}, send_cors=True)
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
-        # This will always be set by the time Twisted calls us.
-        assert request.args is not None
-
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
-        url = parse_string(request, "url")
-        if b"ts" in request.args:
-            ts = parse_integer(request, "ts")
-        else:
+        url = parse_string(request, "url", required=True)
+        ts = parse_integer(request, "ts")
+        if ts is None:
             ts = self.clock.time_msec()
 
         # XXX: we could move this into _do_preview if we wanted.
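
The widened charset patterns above now accept underscores, so charsets such
as Shift_JIS are detected; a quick self-contained check (the pattern is
copied from the diff, the HTML snippet is made up):

    import re

    _charset_match = re.compile(
        br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
    )
    m = _charset_match.search(b'<meta charset="Shift_JIS">')
    assert m is not None and m.group(1) == b"Shift_JIS"
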
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..463ce58dae 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@ import heapq
 import logging
 from collections import defaultdict, namedtuple
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure, measure_func
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 metrics_logger = logging.getLogger("synapse.state.metrics")
 
@@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1
 POWER_KEY = (EventTypes.PowerLevels, "")
 
 
-def _gen_state_id():
+def _gen_state_id() -> str:
     global _NEXT_STATE_ID
     s = "X%d" % (_NEXT_STATE_ID,)
     _NEXT_STATE_ID += 1
@@ -109,7 +114,7 @@ class _StateCacheEntry:
         # `state_id` is either a state_group (and so an int) or a string. This
         # ensures we don't accidentally persist a state_id as a state_group
         if state_group:
-            self.state_id = state_group
+            self.state_id: Union[str, int] = state_group
         else:
             self.state_id = _gen_state_id()
 
@@ -122,7 +127,7 @@ class StateHandler:
     where necessary
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state_store = hs.get_storage().state
@@ -507,7 +512,7 @@ class StateResolutionHandler:
     be storage-independent.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
 
         self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -636,16 +641,20 @@ class StateResolutionHandler:
         """
         try:
             with Measure(self.clock, "state._resolve_events") as m:
-                v = KNOWN_ROOM_VERSIONS[room_version]
-                if v.state_res == StateResolutionVersions.V1:
+                room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+                if room_version_obj.state_res == StateResolutionVersions.V1:
                     return await v1.resolve_events_with_store(
-                        room_id, state_sets, event_map, state_res_store.get_events
+                        room_id,
+                        room_version_obj,
+                        state_sets,
+                        event_map,
+                        state_res_store.get_events,
                     )
                 else:
                     return await v2.resolve_events_with_store(
                         self.clock,
                         room_id,
-                        room_version,
+                        room_version_obj,
                         state_sets,
                         event_map,
                         state_res_store,
@@ -653,13 +662,15 @@ class StateResolutionHandler:
         finally:
             self._record_state_res_metrics(room_id, m.get_resource_usage())
 
-    def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+    def _record_state_res_metrics(
+        self, room_id: str, rusage: ContextResourceUsage
+    ) -> None:
         room_metrics = self._state_res_metrics[room_id]
         room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
         room_metrics.db_time += rusage.db_txn_duration_sec
         room_metrics.db_events += rusage.evt_db_fetch_count
 
-    def _report_metrics(self):
+    def _report_metrics(self) -> None:
         if not self._state_res_metrics:
             # no state res has happened since the last iteration: don't bother logging.
             return
@@ -769,16 +780,13 @@ def _make_state_cache_entry(
     )
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class StateResolutionStore:
     """Interface that allows state resolution algorithms to access the database
     in well defined way.
-
-    Args:
-        store (DataStore)
     """
 
-    store = attr.ib()
+    store: "DataStore"
 
     def get_events(
         self, event_ids: Iterable[str], allow_rejected: bool = False
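
A minimal sketch of the attrs idiom adopted here: with auto_attribs=True,
annotated class attributes become fields, replacing explicit attr.ib()
declarations:

    import attr

    @attr.s(slots=True, auto_attribs=True)
    class Point:
        x: int
        y: int = 0

    p = Point(3)
    assert (p.x, p.y) == (3, 0)
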
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
     auth_events = _create_auth_events_from_maps(
-        unconflicted_state, conflicted_state, state_map
+        room_version, unconflicted_state, conflicted_state, state_map
     )
 
     new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
     state_map.update(state_map_new)
 
     return _resolve_with_state(
-        unconflicted_state, conflicted_state, auth_events, state_map
+        room_version, unconflicted_state, conflicted_state, auth_events, state_map
     )
 
 
@@ -187,6 +188,7 @@ def _seperate(
 
 
 def _create_auth_events_from_maps(
+    room_version: RoomVersion,
     unconflicted_state: StateMap[str],
     conflicted_state: StateMap[Set[str]],
     state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
     """
 
     Args:
+        room_version: The room version.
         unconflicted_state: The unconflicted state map.
         conflicted_state: The conflicted state map.
         state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
             if event_id in state_map:
-                keys = event_auth.auth_types_for_event(state_map[event_id])
+                keys = event_auth.auth_types_for_event(
+                    room_version, state_map[event_id]
+                )
                 for key in keys:
                     if key not in auth_events:
                         auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
 
 
 def _resolve_with_state(
+    room_version: RoomVersion,
     unconflicted_state_ids: MutableStateMap[str],
     conflicted_state_ids: StateMap[Set[str]],
     auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
     }
 
     try:
-        resolved_state = _resolve_state_events(conflicted_state, auth_events)
+        resolved_state = _resolve_state_events(
+            room_version, conflicted_state, auth_events
+        )
     except Exception:
         logger.exception("Failed to resolve state")
         raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
 
 
 def _resolve_state_events(
-    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+    room_version: RoomVersion,
+    conflicted_state: StateMap[List[EventBase]],
+    auth_events: MutableStateMap[EventBase],
 ) -> StateMap[EventBase]:
     """This is where we actually decide which of the conflicted state to
     use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
     if POWER_KEY in conflicted_state:
         events = conflicted_state[POWER_KEY]
         logger.debug("Resolving conflicted power levels %r", events)
-        resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+        resolved_state[POWER_KEY] = _resolve_auth_events(
+            room_version, events, auth_events
+        )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
     for key, events in conflicted_state.items():
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
-            resolved_state[key] = _resolve_auth_events(events, auth_events)
+            resolved_state[key] = _resolve_auth_events(
+                room_version, events, auth_events
+            )
 
     auth_events.update(resolved_state)
 
@@ -290,12 +306,14 @@ def _resolve_state_events(
 
 
 def _resolve_auth_events(
-    events: List[EventBase], auth_events: StateMap[EventBase]
+    room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
 ) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
-        key for event in events for key in event_auth.auth_types_for_event(event)
+        key
+        for event in events
+        for key in event_auth.auth_types_for_event(room_version, event)
     }
 
     new_auth_events = {}
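
For readers wondering why the room version must now be threaded all the way down to `auth_types_for_event`: in newer room versions (notably those implementing MSC3083 restricted joins) the set of state events that can authorise an event depends on the version. Below is a minimal, illustrative sketch of a version-dependent auth-type computation — not Synapse's actual implementation; the content field name is an assumption based on MSC3083.

```python
# Illustrative only: shows why auth types can vary with the room version.
from typing import Dict, List, Tuple


def auth_types_for_event_sketch(
    supports_restricted_joins: bool,  # stands in for a RoomVersion flag
    event_type: str,
    sender: str,
    content: Dict[str, str],
) -> List[Tuple[str, str]]:
    """Return the (type, state_key) pairs that may be needed to auth an event."""
    auth_types = [
        ("m.room.create", ""),
        ("m.room.power_levels", ""),
        ("m.room.member", sender),
    ]
    if event_type == "m.room.member":
        auth_types.append(("m.room.join_rules", ""))
        # Restricted-join room versions (MSC3083) additionally require the
        # membership event of the user who authorised the join.
        if supports_restricted_joins:
            authoriser = content.get("join_authorised_via_users_server")
            if authoriser is not None:
                auth_types.append(("m.room.member", authoriser))
    return auth_types
```
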
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
 async def _iterative_auth_checks(
     clock: Clock,
     room_id: str,
-    room_version: str,
+    room_version: RoomVersion,
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
         Returns the final updated state
     """
     resolved_state = dict(base_state)
-    room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
     for idx, event_id in enumerate(event_ids, start=1):
         event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
                 if ev.rejected_reason is None:
                     auth_events[(ev.type, ev.state_key)] = ev
 
-        for key in event_auth.auth_types_for_event(event):
+        for key in event_auth.auth_types_for_event(room_version, event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
                 ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
 
         try:
             event_auth.check(
-                room_version_obj,
+                room_version,
                 event,
                 auth_events,
                 do_sig_check=False,
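
The practical effect of accepting a `RoomVersion` object rather than a string: the `KNOWN_ROOM_VERSIONS` lookup, and the failure mode for unknown versions, now happen once at the call site instead of inside `_iterative_auth_checks`. A hedged caller-side sketch — `resolve` is an invented stand-in for `resolve_events_with_store` with its remaining arguments bound:

```python
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS


async def resolve_with_version(room_version_id: str, resolve) -> None:
    # Translate the version string to a RoomVersion object up front;
    # unknown versions now fail before resolution starts rather than
    # via a KeyError deep inside the resolver.
    room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
    if room_version is None:
        raise ValueError("Unknown room version: %s" % (room_version_id,))
    await resolve(room_version)
```
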
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ccf9ac51ef..4d4643619f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -832,31 +832,16 @@ class DatabasePool:
         self,
         table: str,
         values: Dict[str, Any],
-        or_ignore: bool = False,
         desc: str = "simple_insert",
-    ) -> bool:
+    ) -> None:
         """Executes an INSERT query on the named table.
 
         Args:
             table: string giving the table name
             values: dict of new column names and values for them
-            or_ignore: bool stating whether an exception should be raised
-                when a conflicting row already exists. If True, False will be
-                returned by the function instead
             desc: description of the transaction, for logging and metrics
-
-        Returns:
-             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
-        try:
-            await self.runInteraction(desc, self.simple_insert_txn, table, values)
-        except self.engine.module.IntegrityError:
-            # We have to do or_ignore flag at this layer, since we can't reuse
-            # a cursor after we receive an error from the db.
-            if not or_ignore:
-                raise
-            return False
-        return True
+        await self.runInteraction(desc, self.simple_insert_txn, table, values)
 
     @staticmethod
     def simple_insert_txn(
@@ -930,7 +915,7 @@ class DatabasePool:
         insertion_values: Optional[Dict[str, Any]] = None,
         desc: str = "simple_upsert",
         lock: bool = True,
-    ) -> Optional[bool]:
+    ) -> bool:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -951,8 +936,8 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
             lock: True to lock the table when doing the upsert.
         Returns:
-            Native upserts always return None. Emulated upserts return True if a
-            new entry was created, False if an existing one was updated.
+            Returns True if a row was inserted or updated (i.e. if `values` is
+            not empty then this always returns True)
         """
         insertion_values = insertion_values or {}
 
@@ -995,7 +980,7 @@ class DatabasePool:
         values: Dict[str, Any],
         insertion_values: Optional[Dict[str, Any]] = None,
         lock: bool = True,
-    ) -> Optional[bool]:
+    ) -> bool:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
@@ -1008,16 +993,15 @@ class DatabasePool:
             insertion_values: additional key/values to use only when inserting
             lock: True to lock the table when doing the upsert.
         Returns:
-            Native upserts always return None. Emulated upserts return True if a
-            new entry was created, False if an existing one was updated.
+            Returns True if a row was inserted or updated (i.e. if `values` is
+            not empty then this always returns True)
         """
         insertion_values = insertion_values or {}
 
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            self.simple_upsert_txn_native_upsert(
+            return self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
-            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -1045,8 +1029,8 @@ class DatabasePool:
             insertion_values: additional key/values to use only when inserting
             lock: True to lock the table when doing the upsert.
         Returns:
-            Returns True if a new entry was created, False if an existing
-            one was updated.
+            Returns True if a row was inserted or updated (i.e. if `values` is
+            not empty then this always returns True)
         """
         insertion_values = insertion_values or {}
 
@@ -1086,8 +1070,7 @@ class DatabasePool:
 
             txn.execute(sql, sqlargs)
             if txn.rowcount > 0:
-                # successfully updated at least one row.
-                return False
+                return True
 
         # We didn't find any existing rows, so insert a new one
         allvalues: Dict[str, Any] = {}
@@ -1111,15 +1094,19 @@ class DatabasePool:
         keyvalues: Dict[str, Any],
         values: Dict[str, Any],
         insertion_values: Optional[Dict[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """
-        Use the native UPSERT functionality in recent PostgreSQL versions.
+        Use the native UPSERT functionality in PostgreSQL.
 
         Args:
             table: The table to upsert into
             keyvalues: The unique key tables and their new values
             values: The nonunique columns and their new values
             insertion_values: additional key/values to use only when inserting
+
+        Returns:
+            Returns True if a row was inserted or updated (i.e. if `values` is
+            not empty then this always returns True)
         """
         allvalues: Dict[str, Any] = {}
         allvalues.update(keyvalues)
@@ -1140,6 +1127,8 @@ class DatabasePool:
         )
         txn.execute(sql, list(allvalues.values()))
 
+        return bool(txn.rowcount)
+
     async def simple_upsert_many(
         self,
         table: str,
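
A condensed, hedged sketch of the migration pattern that the devices.py and transactions.py hunks later in this diff apply now that `or_ignore` is gone: an upsert with an empty `values` dict behaves like INSERT ... ON CONFLICT DO NOTHING, and the boolean return now reliably reports whether a row was written. `db_pool` is assumed to be a `synapse.storage.database.DatabasePool`; table and column names are illustrative.

```python
async def record_once(db_pool, txn_id: str, origin: str, code: int) -> bool:
    # Previously: simple_insert(..., or_ignore=True) swallowed the
    # IntegrityError raised on conflict and returned False.
    #
    # Now: simple_upsert with values={} inserts the row if absent and
    # returns False if it already existed.
    inserted = await db_pool.simple_upsert(
        table="received_transactions",
        keyvalues={"transaction_id": txn_id, "origin": origin},
        values={},
        insertion_values={"response_code": code},
        desc="record_once",
    )
    return inserted
```
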
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3fddea042..8d9f07111d 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -249,7 +249,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-        order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+        order_by: str = UserSortOrder.USER_ID.value,
         direction: str = "f",
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
@@ -297,27 +297,22 @@ class DataStore(
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
-            sql_base = """
+            sql_base = f"""
                 FROM users as u
                 LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
-                {}
-                """.format(
-                where_clause
-            )
+                {where_clause}
+                """
             sql = "SELECT COUNT(*) as total_users " + sql_base
             txn.execute(sql, args)
             count = txn.fetchone()[0]
 
-            sql = """
-                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+            sql = f"""
+                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
+                displayname, avatar_url, creation_ts * 1000 as creation_ts
                 {sql_base}
                 ORDER BY {order_by_column} {order}, u.name ASC
                 LIMIT ? OFFSET ?
-            """.format(
-                sql_base=sql_base,
-                order_by_column=order_by_column,
-                order=order,
-            )
+            """
             args += [limit, start]
             txn.execute(sql, args)
             users = self.db_pool.cursor_to_dict(txn)
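
A note on the f-string rewrite above: interpolating `where_clause`, `order_by_column` and `order` directly into the SQL text is only safe because those fragments are assembled from fixed strings and enum values such as `UserSortOrder`; user-supplied values still travel through `?` placeholders in `args`. A hedged, self-contained illustration of that split (all names invented):

```python
# Trusted fragments are interpolated; untrusted values are parameterised.
ALLOWED_ORDER_COLUMNS = {"name", "creation_ts"}  # illustrative whitelist


def build_user_query(order_by_column: str, search_term: str):
    if order_by_column not in ALLOWED_ORDER_COLUMNS:
        raise ValueError("refusing to interpolate an unexpected column name")
    sql = f"""
        SELECT name FROM users
        WHERE name LIKE ?
        ORDER BY {order_by_column} ASC
    """
    args = ["%" + search_term + "%"]
    return sql, args
```
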
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 18f07d96dc..3816a0ca53 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1078,16 +1078,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return False
 
         try:
-            inserted = await self.db_pool.simple_insert(
+            inserted = await self.db_pool.simple_upsert(
                 "devices",
-                values={
+                keyvalues={
                     "user_id": user_id,
                     "device_id": device_id,
+                },
+                values={},
+                insertion_values={
                     "display_name": initial_device_display_name,
                     "hidden": False,
                 },
                 desc="store_device",
-                or_ignore=True,
             )
             if not inserted:
                 # if the device already exists, check if it's a real device, or
@@ -1099,6 +1101,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 )
                 if hidden:
                     raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+
             self.device_id_exists_cache.set(key, True)
             return inserted
         except StoreError:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index d39368c20e..547e43ab98 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -936,15 +936,46 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # We want to make sure that we do a breadth-first, "depth" ordered
         # search.
 
-        query = (
-            "SELECT depth, prev_event_id FROM event_edges"
-            " INNER JOIN events"
-            " ON prev_event_id = events.event_id"
-            " WHERE event_edges.event_id = ?"
-            " AND event_edges.is_state = ?"
-            " LIMIT ?"
-        )
+        # Look for the prev_event_id connected to the given event_id
+        query = """
+            SELECT depth, prev_event_id FROM event_edges
+            /* Get the depth of the prev_event_id from the events table */
+            INNER JOIN events
+            ON prev_event_id = events.event_id
+            /* Find an event which matches the given event_id */
+            WHERE event_edges.event_id = ?
+            AND event_edges.is_state = ?
+            LIMIT ?
+        """
+
+        # Look for the "insertion" events connected to the given event_id
+        connected_insertion_event_query = """
+            SELECT e.depth, i.event_id FROM insertion_event_edges AS i
+            /* Get the depth of the insertion event from the events table */
+            INNER JOIN events AS e USING (event_id)
+            /* Find an insertion event which points via prev_events to the given event_id */
+            WHERE i.insertion_prev_event_id = ?
+            LIMIT ?
+        """
+
+        # Find any chunk connections of a given insertion event
+        chunk_connection_query = """
+            SELECT e.depth, c.event_id FROM insertion_events AS i
+            /* Find the chunk that connects to the given insertion event */
+            INNER JOIN chunk_events AS c
+            ON i.next_chunk_id = c.chunk_id
+            /* Get the depth of the chunk start event from the events table */
+            INNER JOIN events AS e USING (event_id)
+            /* Find an insertion event which matches the given event_id */
+            WHERE i.event_id = ?
+            LIMIT ?
+        """
 
+        # In a PriorityQueue, the lowest valued entries are retrieved first.
+        # We're using depth as the priority in the queue.
+        # Depth is lowest at the oldest-in-time message and highest at the
+        # newest-in-time message. We add events to the queue with a negative depth so that
+        # we process the newest-in-time messages first going backwards in time.
         queue = PriorityQueue()
 
         for event_id in event_list:
@@ -970,9 +1001,48 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
             event_results.add(event_id)
 
+            # Try and find any potential historical chunks of message history.
+            #
+            # First we look for an insertion event connected to the current
+            # event (by prev_event). If we find any, we need to go and try to
+            # find any chunk events connected to the insertion event (by
+            # chunk_id). If we find any, we'll add them to the queue and
+            # navigate up the DAG like normal in the next iteration of the loop.
+            txn.execute(
+                connected_insertion_event_query, (event_id, limit - len(event_results))
+            )
+            connected_insertion_event_id_results = txn.fetchall()
+            logger.debug(
+                "_get_backfill_events: connected_insertion_event_query %s",
+                connected_insertion_event_id_results,
+            )
+            for row in connected_insertion_event_id_results:
+                connected_insertion_event_depth = row[0]
+                connected_insertion_event = row[1]
+                queue.put((-connected_insertion_event_depth, connected_insertion_event))
+
+                # Find any chunk connections for the given insertion event
+                txn.execute(
+                    chunk_connection_query,
+                    (connected_insertion_event, limit - len(event_results)),
+                )
+                chunk_start_event_id_results = txn.fetchall()
+                logger.debug(
+                    "_get_backfill_events: chunk_start_event_id_results %s",
+                    chunk_start_event_id_results,
+                )
+                for row in chunk_start_event_id_results:
+                    if row[1] not in event_results:
+                        queue.put((-row[0], row[1]))
+
+            # Navigate up the DAG by prev_event
             txn.execute(query, (event_id, False, limit - len(event_results)))
+            prev_event_id_results = txn.fetchall()
+            logger.debug(
+                "_get_backfill_events: prev_event_ids %s", prev_event_id_results
+            )
 
-            for row in txn:
+            for row in prev_event_id_results:
                 if row[1] not in event_results:
                     queue.put((-row[0], row[1]))
 
@@ -1227,12 +1297,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             (count,) = txn.fetchone()
 
             txn.execute(
-                "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
+                "SELECT min(received_ts) FROM federation_inbound_events_staging"
             )
 
             (received_ts,) = txn.fetchone()
 
-            age = self._clock.time_msec() - received_ts
+            # If there is nothing in the staging area, default the age to 0.
+            age = 0
+            if received_ts is not None:
+                age = self._clock.time_msec() - received_ts
 
             return count, age
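
The priority-queue comment in the backfill hunk above is worth seeing in isolation: Python's `PriorityQueue` pops the smallest entry first, so pushing `(-depth, event_id)` makes the deepest (newest-in-time) events come out first. A self-contained sketch with invented event data:

```python
from queue import PriorityQueue

# event_id -> depth; higher depth means newer in the room's history
events = {"$oldest": 1, "$middle": 3, "$newest": 5}

queue: "PriorityQueue[tuple]" = PriorityQueue()
for event_id, depth in events.items():
    # Negate the depth so the lowest queue value is the newest event.
    queue.put((-depth, event_id))

while not queue.empty():
    neg_depth, event_id = queue.get()
    print(event_id, -neg_depth)  # prints $newest 5, $middle 3, $oldest 1
```
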
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a396a201d4..86baf397fb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1502,6 +1502,9 @@ class PersistEventsStore:
 
             self._handle_event_relations(txn, event)
 
+            self._handle_insertion_event(txn, event)
+            self._handle_chunk_event(txn, event)
+
             # Store the labels for this event.
             labels = event.content.get(EventContentFields.LABELS)
             if labels:
@@ -1754,6 +1757,94 @@ class PersistEventsStore:
         if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
+    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+        """Handles keeping track of insertion events and edges/connections.
+        Part of MSC2716.
+
+        Args:
+            txn: The database transaction object
+            event: The event to process
+        """
+
+        if event.type != EventTypes.MSC2716_INSERTION:
+            # Not an insertion event
+            return
+
+        # Skip processing an insertion event if the room version doesn't
+        # support it.
+        room_version = self.store.get_room_version_txn(txn, event.room_id)
+        if not room_version.msc2716_historical:
+            return
+
+        next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
+        if next_chunk_id is None:
+            # Invalid insertion event without a next chunk ID
+            return
+
+        logger.debug(
+            "_handle_insertion_event (next_chunk_id=%s) %s", next_chunk_id, event
+        )
+
+        # Keep track of the insertion event and its next chunk ID
+        self.db_pool.simple_insert_txn(
+            txn,
+            table="insertion_events",
+            values={
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "next_chunk_id": next_chunk_id,
+            },
+        )
+
+        # Insert an edge for every prev_event connection
+        for prev_event_id in event.prev_events:
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="insertion_event_edges",
+                values={
+                    "event_id": event.event_id,
+                    "room_id": event.room_id,
+                    "insertion_prev_event_id": prev_event_id,
+                },
+            )
+
+    def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
+        """Handles inserting the chunk edges/connections between the chunk event
+        and an insertion event. Part of MSC2716.
+
+        Args:
+            txn: The database transaction object
+            event: The event to process
+        """
+
+        if event.type != EventTypes.MSC2716_CHUNK:
+            # Not a chunk event
+            return
+
+        # Skip processing a chunk event if the room version doesn't
+        # support it.
+        room_version = self.store.get_room_version_txn(txn, event.room_id)
+        if not room_version.msc2716_historical:
+            return
+
+        chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
+        if chunk_id is None:
+            # Invalid chunk event without a chunk ID
+            return
+
+        logger.debug("_handle_chunk_event chunk_id=%s %s", chunk_id, event)
+
+        # Keep track of the chunk event and its chunk ID
+        self.db_pool.simple_insert_txn(
+            txn,
+            table="chunk_events",
+            values={
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "chunk_id": chunk_id,
+            },
+        )
+
     def _handle_redaction(self, txn, redacted_event_id):
         """Handles receiving a redaction and checking whether we need to remove
         any redacted relations from the database.
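
For orientation, the events these two handlers consume look roughly like the dicts below. The literal content keys are assumptions based on MSC2716 (the real code reads them via `EventContentFields.MSC2716_NEXT_CHUNK_ID` and `EventContentFields.MSC2716_CHUNK_ID`); the room and chunk IDs are illustrative.

```python
insertion_event = {
    "type": "org.matrix.msc2716.insertion",
    "room_id": "!room:example.org",
    # _handle_insertion_event records this so later chunks can attach here.
    "content": {"org.matrix.msc2716.next_chunk_id": "chunk-abc"},
}

chunk_event = {
    "type": "org.matrix.msc2716.chunk",
    "room_id": "!room:example.org",
    # _handle_chunk_event records this; it matches some insertion event's
    # next_chunk_id, linking the chunk to the point in history it extends.
    "content": {"org.matrix.msc2716.chunk_id": "chunk-abc"},
}
```
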
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index fe25638289..d213b26703 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -297,17 +297,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         Args:
             txn (cursor):
             user_id (str): user to add/update
-
-        Returns:
-            bool: True if a new entry was created, False if an
-            existing one was updated.
         """
 
         # Am consciously deciding to lock the table on the basis that it ought
         # never be a big table and alternative approaches (batching multiple
         # upserts into a single txn) introduced a lot of extra complexity.
         # See https://github.com/matrix-org/synapse/issues/3854 for more
-        is_insert = self.db_pool.simple_upsert_txn(
+        self.db_pool.simple_upsert_txn(
             txn,
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
@@ -322,8 +318,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
-        return is_insert
-
     async def populate_monthly_active_users(self, user_id):
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6ddafe5434..443e5f3315 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore):
         self,
         start: int,
         limit: int,
-        order_by: RoomSortOrder,
+        order_by: str,
         reverse_order: bool,
         search_term: Optional[str],
     ) -> Tuple[List[Dict[str, Any]], int]:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 1757064a68..8e22da99ae 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -22,7 +22,7 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
@@ -58,15 +58,32 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
         """Get the room_version of a given room
-
         Raises:
             NotFoundError: if the room is unknown
+            UnsupportedRoomVersionError: if the room uses an unknown room version.
+                Typically this happens if support for the room's version has been
+                removed from Synapse.
+        """
+        return await self.db_pool.runInteraction(
+            "get_room_version_txn",
+            self.get_room_version_txn,
+            room_id,
+        )
 
+    def get_room_version_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> RoomVersion:
+        """Get the room_version of a given room
+        Args:
+            txn: Transaction object
+            room_id: The room_id of the room you are trying to get the version for
+        Raises:
+            NotFoundError: if the room is unknown
             UnsupportedRoomVersionError: if the room uses an unknown room version.
                 Typically this happens if support for the room's version has been
                 removed from Synapse.
         """
-        room_version_id = await self.get_room_version_id(room_id)
+        room_version_id = self.get_room_version_id_txn(txn, room_id)
         v = KNOWN_ROOM_VERSIONS.get(room_version_id)
 
         if not v:
@@ -80,7 +97,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     @cached(max_entries=10000)
     async def get_room_version_id(self, room_id: str) -> str:
         """Get the room_version of a given room
+        Raises:
+            NotFoundError: if the room is unknown
+        """
+        return await self.db_pool.runInteraction(
+            "get_room_version_id_txn",
+            self.get_room_version_id_txn,
+            room_id,
+        )
 
+    def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:
+        """Get the room_version of a given room
+        Args:
+            txn: Transaction object
+            room_id: The room_id of the room you are trying to get the version for
         Raises:
             NotFoundError: if the room is unknown
         """
@@ -88,24 +118,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # First we try looking up room version from the database, but for old
         # rooms we might not have added the room version to it yet so we fall
         # back to previous behaviour and look in current state events.
-
+        #
         # We really should have an entry in the rooms table for every room we
         # care about, but let's be a bit paranoid (at least while the background
         # update is happening) to avoid breaking existing rooms.
-        version = await self.db_pool.simple_select_one_onecol(
+        room_version = self.db_pool.simple_select_one_onecol_txn(
+            txn,
             table="rooms",
             keyvalues={"room_id": room_id},
             retcol="room_version",
-            desc="get_room_version",
             allow_none=True,
         )
 
-        if version is not None:
-            return version
+        if room_version is None:
+            raise NotFoundError("Could not find room_version for %s" % (room_id,))
 
-        # Retrieve the room's create event
-        create_event = await self.get_create_event_for_room(room_id)
-        return create_event.content.get("room_version", "1")
+        return room_version
 
     async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
         """Get the predecessor of an upgraded room if it exists.
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 59d67c255b..42edbcc057 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -75,6 +75,7 @@ class UserSortOrder(Enum):
     USER_TYPE = ordered alphabetically by `user_type`
     AVATAR_URL = ordered alphabetically by `avatar_url`
     SHADOW_BANNED = ordered by `shadow_banned`
+    CREATION_TS = ordered by `creation_ts`
     """
 
     MEDIA_LENGTH = "media_length"
@@ -88,6 +89,7 @@ class UserSortOrder(Enum):
     USER_TYPE = "user_type"
     AVATAR_URL = "avatar_url"
     SHADOW_BANNED = "shadow_banned"
+    CREATION_TS = "creation_ts"
 
 
 class StatsStore(StateDeltasStore):
@@ -647,10 +649,10 @@ class StatsStore(StateDeltasStore):
         limit: int,
         from_ts: Optional[int] = None,
         until_ts: Optional[int] = None,
-        order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
+        order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Optional[str] = "f",
         search_term: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], Dict[str, int]]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
         (size and number). This will return a json list of users and the
         total number of users matching the filter criteria.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d211c423b2..7728d5f102 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -134,16 +134,18 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             response_dict: The response, to be encoded into JSON.
         """
 
-        await self.db_pool.simple_insert(
+        await self.db_pool.simple_upsert(
             table="received_transactions",
-            values={
+            keyvalues={
                 "transaction_id": transaction_id,
                 "origin": origin,
+            },
+            values={},
+            insertion_values={
                 "response_code": code,
                 "response_json": db_binary_type(encode_canonical_json(response_dict)),
                 "ts": self._clock.time_msec(),
             },
-            or_ignore=True,
             desc="set_received_txn_response",
         )
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a6bfb4902a..9d28d69ac7 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -377,7 +377,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             avatar_url = None
 
         def _update_profile_in_user_dir_txn(txn):
-            new_entry = self.db_pool.simple_upsert_txn(
+            self.db_pool.simple_upsert_txn(
                 txn,
                 table="user_directory",
                 keyvalues={"user_id": user_id},
@@ -388,8 +388,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
-                if self.database_engine.can_native_upsert:
-                    sql = """
+                sql = """
                         INSERT INTO user_directory_search(user_id, vector)
                         VALUES (?,
                             setweight(to_tsvector('simple', ?), 'A')
@@ -397,58 +396,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                             || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                         ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                     """
-                    txn.execute(
-                        sql,
-                        (
-                            user_id,
-                            get_localpart_from_id(user_id),
-                            get_domain_from_id(user_id),
-                            display_name,
-                        ),
-                    )
-                else:
-                    # TODO: Remove this code after we've bumped the minimum version
-                    # of postgres to always support upserts, so we can get rid of
-                    # `new_entry` usage
-                    if new_entry is True:
-                        sql = """
-                            INSERT INTO user_directory_search(user_id, vector)
-                            VALUES (?,
-                                setweight(to_tsvector('simple', ?), 'A')
-                                || setweight(to_tsvector('simple', ?), 'D')
-                                || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
-                            )
-                        """
-                        txn.execute(
-                            sql,
-                            (
-                                user_id,
-                                get_localpart_from_id(user_id),
-                                get_domain_from_id(user_id),
-                                display_name,
-                            ),
-                        )
-                    elif new_entry is False:
-                        sql = """
-                            UPDATE user_directory_search
-                            SET vector = setweight(to_tsvector('simple', ?), 'A')
-                                || setweight(to_tsvector('simple', ?), 'D')
-                                || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
-                            WHERE user_id = ?
-                        """
-                        txn.execute(
-                            sql,
-                            (
-                                get_localpart_from_id(user_id),
-                                get_domain_from_id(user_id),
-                                display_name,
-                                user_id,
-                            ),
-                        )
-                    else:
-                        raise RuntimeError(
-                            "upsert returned None when 'can_native_upsert' is False"
-                        )
+                txn.execute(
+                    sql,
+                    (
+                        user_id,
+                        get_localpart_from_id(user_id),
+                        get_domain_from_id(user_id),
+                        display_name,
+                    ),
+                )
             elif isinstance(self.database_engine, Sqlite3Engine):
                 value = "%s %s" % (user_id, display_name) if display_name else user_id
                 self.db_pool.simple_upsert_txn(
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e38461adbc..f839c0c24f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -372,18 +372,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
     async def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+        self,
+        event_id: str,
+        room_id: str,
+        prev_group: Optional[int],
+        delta_ids: Optional[StateMap[str]],
+        current_state_ids: StateMap[str],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
+            event_id: The event ID for which the state was calculated
+            room_id
+            prev_group: A previous state group for the room, optional.
+            delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
+            current_state_ids: The state to store. Map of (type, state_key)
                 to event_id.
 
         Returns:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a39877f0d5..0e8270746d 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             end_item = queue[-1]
         else:
             # need to make a new queue item
-            deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+            deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
+                defer.Deferred(), consumeErrors=True
+            )
 
             end_item = _EventPersistQueueItem(
                 events_and_contexts=[],
diff --git a/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql b/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql
new file mode 100644
index 0000000000..7d7bafc631
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql
@@ -0,0 +1,49 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a table that keeps track of "insertion" events and
+-- their next_chunk_ids so we can navigate to the next chunk of history.
+CREATE TABLE IF NOT EXISTS insertion_events(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    next_chunk_id TEXT NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_events_event_id ON insertion_events(event_id);
+CREATE INDEX IF NOT EXISTS insertion_events_next_chunk_id ON insertion_events(next_chunk_id);
+
+-- Add a table that keeps track of all of the events we are inserting between.
+-- We use this when navigating the DAG: when we hit an event which matches
+-- `insertion_prev_event_id`, we should backfill from the "insertion" event and
+-- navigate the historical messages from there.
+CREATE TABLE IF NOT EXISTS insertion_event_edges(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    insertion_prev_event_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);
+CREATE INDEX IF NOT EXISTS insertion_event_edges_insertion_room_id ON insertion_event_edges(room_id);
+CREATE INDEX IF NOT EXISTS insertion_event_edges_insertion_prev_event_id ON insertion_event_edges(insertion_prev_event_id);
+
+-- Add a table that keeps track of how each chunk is labeled. The chunks are
+-- connected together based on an insertion event's `next_chunk_id`.
+CREATE TABLE IF NOT EXISTS chunk_events(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    chunk_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS chunk_events_event_id ON chunk_events(event_id);
+CREATE INDEX IF NOT EXISTS chunk_events_chunk_id ON chunk_events(chunk_id);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f8fbba9d38..e5400d681a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -570,8 +570,8 @@ class StateGroupStorage:
         event_id: str,
         room_id: str,
         prev_group: Optional[int],
-        delta_ids: Optional[dict],
-        current_state_ids: dict,
+        delta_ids: Optional[StateMap[str]],
+        current_state_ids: StateMap[str],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 13d300588b..cf4005984b 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -47,20 +47,22 @@ class PaginationConfig:
     ) -> "PaginationConfig":
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
 
-        from_tok = parse_string(request, "from")
-        to_tok = parse_string(request, "to")
+        from_tok_str = parse_string(request, "from")
+        to_tok_str = parse_string(request, "to")
 
         try:
-            if from_tok == "END":
+            from_tok = None
+            if from_tok_str == "END":
                 from_tok = None  # For backwards compat.
-            elif from_tok:
-                from_tok = await StreamToken.from_string(store, from_tok)
+            elif from_tok_str:
+                from_tok = await StreamToken.from_string(store, from_tok_str)
         except Exception:
             raise SynapseError(400, "'from' parameter is invalid")
 
         try:
-            if to_tok:
-                to_tok = await StreamToken.from_string(store, to_tok)
+            to_tok = None
+            if to_tok_str:
+                to_tok = await StreamToken.from_string(store, to_tok_str)
         except Exception:
             raise SynapseError(400, "'to' parameter is invalid")
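
The renamed variables above fix a typing problem: re-using `from_tok` for both the raw query string and the parsed `StreamToken` gives it the unhelpful type `Union[str, StreamToken, None]`. Keeping the raw and parsed values in separate names keeps each type precise. A hedged, simplified illustration, with `int` standing in for `StreamToken`:

```python
from typing import Optional


def parse_bound(raw: Optional[str]) -> Optional[int]:
    # `raw` stays Optional[str]; `token` is Optional[int] throughout,
    # so a type checker can follow both.
    token: Optional[int] = None
    if raw is not None and raw != "END":  # "END" kept for backwards compat
        token = int(raw)
    return token
```
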
 
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 014db1355b..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
     Awaitable,
     Callable,
     Dict,
+    Generic,
     Hashable,
     Iterable,
     List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
+from twisted.python.failure import Failure
 
 from synapse.logging.context import (
     PreserveLoggingContext,
@@ -49,8 +51,10 @@ from synapse.util import Clock, unwrapFirstError
 
 logger = logging.getLogger(__name__)
 
+_T = TypeVar("_T")
 
-class ObservableDeferred:
+
+class ObservableDeferred(Generic[_T]):
     """Wraps a deferred object so that we can add observer deferreds. These
     observer deferreds do not affect the callback chain of the original
     deferred.
@@ -68,7 +72,7 @@ class ObservableDeferred:
 
     __slots__ = ["_deferred", "_observers", "_result"]
 
-    def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
         object.__setattr__(self, "_deferred", deferred)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", set())
@@ -113,7 +117,7 @@ class ObservableDeferred:
 
         deferred.addCallbacks(callback, errback)
 
-    def observe(self) -> defer.Deferred:
+    def observe(self) -> "defer.Deferred[_T]":
         """Observe the underlying deferred.
 
         This returns a brand new deferred that is resolved when the underlying
@@ -121,7 +125,7 @@ class ObservableDeferred:
         affect the underlying deferred.
         """
         if not self._result:
-            d = defer.Deferred()
+            d: "defer.Deferred[_T]" = defer.Deferred()
 
             def remove(r):
                 self._observers.discard(d)
@@ -135,7 +139,7 @@ class ObservableDeferred:
             success, res = self._result
             return defer.succeed(res) if success else defer.fail(res)
 
-    def observers(self) -> List[defer.Deferred]:
+    def observers(self) -> "List[defer.Deferred[_T]]":
         return self._observers
 
     def has_called(self) -> bool:
@@ -144,7 +148,7 @@ class ObservableDeferred:
     def has_succeeded(self) -> bool:
         return self._result is not None and self._result[0] is True
 
-    def get_result(self) -> Any:
+    def get_result(self) -> Union[_T, Failure]:
         return self._result[1]
 
     def __getattr__(self, name: str) -> Any:
@@ -415,7 +419,7 @@ class ReadWriteLock:
         self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
     async def read(self, key: str) -> ContextManager:
-        new_defer = defer.Deferred()
+        new_defer: "defer.Deferred[None]" = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.setdefault(key, set())
         curr_writer = self.key_to_current_writer.get(key, None)
@@ -438,7 +442,7 @@ class ReadWriteLock:
         return _ctx_manager()
 
     async def write(self, key: str) -> ContextManager:
-        new_defer = defer.Deferred()
+        new_defer: "defer.Deferred[None]" = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.get(key, set())
         curr_writer = self.key_to_current_writer.get(key, None)
@@ -471,10 +475,8 @@ R = TypeVar("R")
 
 
 def timeout_deferred(
-    deferred: defer.Deferred,
-    timeout: float,
-    reactor: IReactorTime,
-) -> defer.Deferred:
+    deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
+) -> "defer.Deferred[_T]":
     """The in built twisted `Deferred.addTimeout` fails to time out deferreds
     that have a canceller that throws exceptions. This method creates a new
     deferred that wraps and times out the given deferred, correctly handling
@@ -497,7 +499,7 @@ def timeout_deferred(
     Returns:
         A new Deferred, which will errback with defer.TimeoutError on timeout.
     """
-    new_d = defer.Deferred()
+    new_d: "defer.Deferred[_T]" = defer.Deferred()
 
     timed_out = [False]
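
A hedged usage sketch of the now-generic `ObservableDeferred`: the type parameter flows from the wrapped deferred through `observe()` to every observer, so a type checker can track what the observers receive. The fan-out scenario below is invented.

```python
from twisted.internet import defer

from synapse.util.async_helpers import ObservableDeferred


def fan_out() -> None:
    source: "defer.Deferred[int]" = defer.Deferred()
    observable: ObservableDeferred[int] = ObservableDeferred(
        source, consumeErrors=True
    )

    # Each observer gets its own Deferred[int], independent of the others.
    first = observable.observe()
    second = observable.observe()
    first.addCallback(lambda n: print("first saw", n))
    second.addCallback(lambda n: print("second saw", n))

    source.callback(42)  # both observers fire with 42
```
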
 
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 891bee0b33..e58dd91eda 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import enum
 from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
 
 from twisted.internet.defer import Deferred
@@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 TV = TypeVar("TV")
 
 
+class _Sentinel(enum.Enum):
+    sentinel = object()
+
+
 class CachedCall(Generic[TV]):
     """A wrapper for asynchronous calls whose results should be shared
 
@@ -65,7 +69,7 @@ class CachedCall(Generic[TV]):
         """
         self._callable: Optional[Callable[[], Awaitable[TV]]] = f
         self._deferred: Optional[Deferred] = None
-        self._result: Union[None, Failure, TV] = None
+        self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
 
     async def get(self) -> TV:
         """Kick off the call if necessary, and return the result"""
@@ -78,8 +82,9 @@ class CachedCall(Generic[TV]):
             self._callable = None
 
             # once the deferred completes, store the result. We cannot simply leave the
-            # result in the deferred, since if it's a Failure, GCing the deferred
-            # would then log a critical error about unhandled Failures.
+            # result in the deferred, since `awaiting` a deferred destroys its result.
+            # (Also, if it's a Failure, GCing the deferred would log a critical error
+            # about unhandled Failures)
             def got_result(r):
                 self._result = r
 
@@ -92,13 +97,15 @@ class CachedCall(Generic[TV]):
         #    and any eventual exception may not be reported.
 
         # we can now await the deferred, and once it completes, return the result.
-        await make_deferred_yieldable(self._deferred)
+        if isinstance(self._result, _Sentinel):
+            await make_deferred_yieldable(self._deferred)
+            assert not isinstance(self._result, _Sentinel)
+
+        if isinstance(self._result, Failure):
+            self._result.raiseException()
+            raise AssertionError("unexpected return from Failure.raiseException")
 
-        # I *think* this is the easiest way to correctly raise a Failure without having
-        # to gut-wrench into the implementation of Deferred.
-        d = Deferred()
-        d.callback(self._result)
-        return await d
+        return self._result
 
 
 class RetryOnExceptionCachedCall(Generic[TV]):
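
The `_Sentinel` change above is a generic pattern worth noting: a dedicated enum member, rather than `None`, marks "no result yet", letting the cache distinguish an unset slot from a call that legitimately returned `None`. A minimal, standalone sketch of the same pattern (not Synapse code):

```python
import enum
from typing import Union


class _Sentinel(enum.Enum):
    sentinel = object()


class OnceCell:
    """Stores a single result, which may itself be None."""

    def __init__(self) -> None:
        self._result: Union[_Sentinel, object] = _Sentinel.sentinel

    def set(self, value: object) -> None:
        self._result = value

    def get(self) -> object:
        if isinstance(self._result, _Sentinel):
            raise RuntimeError("no result yet")
        return self._result  # may legitimately be None
```
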
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 8c6fafc677..b6456392cd 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,7 +16,16 @@
 
 import enum
 import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
+from typing import (
+    Callable,
+    Generic,
+    Iterable,
+    MutableMapping,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from prometheus_client import Gauge
 
@@ -166,7 +175,7 @@ class DeferredCache(Generic[KT, VT]):
     def set(
         self,
         key: KT,
-        value: defer.Deferred,
+        value: "defer.Deferred[VT]",
         callback: Optional[Callable[[], None]] = None,
     ) -> defer.Deferred:
         """Adds a new entry to the cache (or updates an existing one).
@@ -214,7 +223,7 @@ class DeferredCache(Generic[KT, VT]):
         if value.called:
             result = value.result
             if not isinstance(result, failure.Failure):
-                self.cache.set(key, result, callbacks)
+                self.cache.set(key, cast(VT, result), callbacks)
             return value
 
         # otherwise, we'll add an entry to the _pending_deferred_cache for now,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1e8e6b1d01..1ca31e41ac 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -413,7 +413,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 # relevant result for that key.
                 deferreds_map = {}
                 for arg in missing:
-                    deferred = defer.Deferred()
+                    deferred: "defer.Deferred[Any]" = defer.Deferred()
                     deferreds_map[arg] = deferred
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)