summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/config/registration.py21
-rw-r--r--synapse/federation/federation_server.py121
-rw-r--r--synapse/federation/transport/server.py592
-rw-r--r--synapse/handlers/auth.py132
-rw-r--r--synapse/handlers/federation.py219
-rw-r--r--synapse/handlers/register.py115
-rw-r--r--synapse/http/server.py2
-rw-r--r--synapse/http/servlet.py50
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/replication/http/login.py13
-rw-r--r--synapse/rest/client/v1/login.py171
-rw-r--r--synapse/rest/client/v2_alpha/register.py88
-rw-r--r--synapse/rest/client/v2_alpha/sync.py69
-rw-r--r--synapse/storage/databases/main/registration.py207
-rw-r--r--synapse/storage/schema/main/delta/59/14refresh_tokens.sql34
16 files changed, 1450 insertions, 391 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index edf1b918eb..29cf257633 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -245,6 +245,11 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
+            # Mark the token as used. This is used to invalidate old refresh
+            # tokens after some time.
+            if not user_info.token_used and token_id is not None:
+                await self.store.mark_access_token_as_used(token_id)
+
             requester = create_requester(
                 user_info.user_id,
                 token_id,
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index d9dc55a0c3..0ad919b139 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -119,6 +119,27 @@ class RegistrationConfig(Config):
             session_lifetime = self.parse_duration(session_lifetime)
         self.session_lifetime = session_lifetime
 
+        # The `access_token_lifetime` applies for tokens that can be renewed
+        # using a refresh token, as per MSC2918. If it is `None`, the refresh
+        # token mechanism is disabled.
+        #
+        # Since it is incompatible with the `session_lifetime` mechanism, it is set to
+        # `None` by default if a `session_lifetime` is set.
+        access_token_lifetime = config.get(
+            "access_token_lifetime", "5m" if session_lifetime is None else None
+        )
+        if access_token_lifetime is not None:
+            access_token_lifetime = self.parse_duration(access_token_lifetime)
+        self.access_token_lifetime = access_token_lifetime
+
+        if session_lifetime is not None and access_token_lifetime is not None:
+            raise ConfigError(
+                "The refresh token mechanism is incompatible with the "
+                "`session_lifetime` option. Consider disabling the "
+                "`session_lifetime` option or disabling the refresh token "
+                "mechanism by removing the `access_token_lifetime` option."
+            )
+
         # The success template used during fallback auth.
         self.fallback_success_template = self.read_template("auth_success.html")
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2b07f18529..341965047a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
-from synapse.api.constants import EduTypes, EventTypes
+from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -46,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
@@ -537,26 +538,21 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict
+        self, origin: str, content: JsonDict, room_id: str
     ) -> Dict[str, Any]:
-        logger.debug("on_send_join_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+        context = await self._on_send_membership_event(
+            origin, content, Membership.JOIN, room_id
+        )
 
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        prev_state_ids = await context.get_prev_state_ids()
+        state_ids = list(prev_state_ids.values())
+        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
+        state = await self.store.get_events(state_ids)
 
-        res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
         return {
-            "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
-            "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+            "state": [p.get_pdu_json(time_now) for p in state.values()],
+            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
         }
 
     async def on_make_leave_request(
@@ -571,21 +567,11 @@ class FederationServer(FederationBase):
         time_now = self._clock.time_msec()
         return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
 
-    async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict:
+    async def on_send_leave_request(
+        self, origin: str, content: JsonDict, room_id: str
+    ) -> dict:
         logger.debug("on_send_leave_request: content: %s", content)
-
-        assert_params_in_dict(content, ["room_id"])
-        room_version = await self.store.get_room_version(content["room_id"])
-        pdu = event_from_pdu_json(content, room_version)
-
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
-
-        logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
-
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
-
-        await self.handler.on_send_leave_request(origin, pdu)
+        await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
         return {}
 
     async def on_make_knock_request(
@@ -651,39 +637,76 @@ class FederationServer(FederationBase):
         Returns:
             The stripped room state.
         """
-        logger.debug("on_send_knock_request: content: %s", content)
+        event_context = await self._on_send_membership_event(
+            origin, content, Membership.KNOCK, room_id
+        )
+
+        # Retrieve stripped state events from the room and send them back to the remote
+        # server. This will allow the remote server's clients to display information
+        # related to the room while the knock request is pending.
+        stripped_room_state = (
+            await self.store.get_stripped_room_state_from_event_context(
+                event_context, self._room_prejoin_state_types
+            )
+        )
+        return {"knock_state_events": stripped_room_state}
+
+    async def _on_send_membership_event(
+        self, origin: str, content: JsonDict, membership_type: str, room_id: str
+    ) -> EventContext:
+        """Handle an on_send_{join,leave,knock} request
+
+        Does some preliminary validation before passing the request on to the
+        federation handler.
+
+        Args:
+            origin: The (authenticated) requesting server
+            content: The body of the send_* request - a complete membership event
+            membership_type: The expected membership type (join or leave, depending
+                on the endpoint)
+            room_id: The room_id from the request, to be validated against the room_id
+                in the event
+
+        Returns:
+            The context of the event after inserting it into the room graph.
+
+        Raises:
+            SynapseError if there is a problem with the request, including things like
+               the room_id not matching or the event not being authorized.
+        """
+        assert_params_in_dict(content, ["room_id"])
+        if content["room_id"] != room_id:
+            raise SynapseError(
+                400,
+                "Room ID in body does not match that in request path",
+                Codes.BAD_JSON,
+            )
 
         room_version = await self.store.get_room_version(room_id)
 
-        # Check that this room supports knocking as defined by its room version
-        if not room_version.msc2403_knocking:
+        if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
                 "This room version does not support knocking",
                 errcode=Codes.FORBIDDEN,
             )
 
-        pdu = event_from_pdu_json(content, room_version)
+        event = event_from_pdu_json(content, room_version)
 
-        origin_host, _ = parse_server_name(origin)
-        await self.check_server_matches_acl(origin_host, pdu.room_id)
+        if event.type != EventTypes.Member or not event.is_state():
+            raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)
 
-        logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures)
+        if event.content.get("membership") != membership_type:
+            raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)
 
-        pdu = await self._check_sigs_and_hash(room_version, pdu)
+        origin_host, _ = parse_server_name(origin)
+        await self.check_server_matches_acl(origin_host, event.room_id)
 
-        # Handle the event, and retrieve the EventContext
-        event_context = await self.handler.on_send_knock_request(origin, pdu)
+        logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
 
-        # Retrieve stripped state events from the room and send them back to the remote
-        # server. This will allow the remote server's clients to display information
-        # related to the room while the knock request is pending.
-        stripped_room_state = (
-            await self.store.get_stripped_room_state_from_event_context(
-                event_context, self._room_prejoin_state_types
-            )
-        )
-        return {"knock_state_events": stripped_room_state}
+        event = await self._check_sigs_and_hash(room_version, event)
+
+        return await self.handler.on_send_membership_event(origin, event)
 
     async def on_event_auth(
         self, origin: str, room_id: str, event_id: str
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index bed47f8abd..d37d9565fc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -15,7 +15,19 @@
 import functools
 import logging
 import re
-from typing import Container, Mapping, Optional, Sequence, Tuple, Type
+from typing import (
+    Container,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
+
+from typing_extensions import Literal
 
 import synapse
 from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
@@ -56,15 +68,15 @@ logger = logging.getLogger(__name__)
 class TransportLayerServer(JsonResource):
     """Handles incoming federation HTTP requests"""
 
-    def __init__(self, hs, servlet_groups=None):
+    def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
         """Initialize the TransportLayerServer
 
         Will by default register all servlets. For custom behaviour, pass in
         a list of servlet_groups to register.
 
         Args:
-            hs (synapse.server.HomeServer): homeserver
-            servlet_groups (list[str], optional): List of servlet groups to register.
+            hs: homeserver
+            servlet_groups: List of servlet groups to register.
                 Defaults to ``DEFAULT_SERVLET_GROUPS``.
         """
         self.hs = hs
@@ -78,7 +90,7 @@ class TransportLayerServer(JsonResource):
 
         self.register_servlets()
 
-    def register_servlets(self):
+    def register_servlets(self) -> None:
         register_servlets(
             self.hs,
             resource=self,
@@ -91,14 +103,10 @@ class TransportLayerServer(JsonResource):
 class AuthenticationError(SynapseError):
     """There was a problem authenticating the request"""
 
-    pass
-
 
 class NoAuthenticationError(AuthenticationError):
     """The request had no authentication information"""
 
-    pass
-
 
 class Authenticator:
     def __init__(self, hs: HomeServer):
@@ -410,13 +418,18 @@ class FederationSendServlet(BaseFederationServerServlet):
     RATELIMIT = False
 
     # This is when someone is trying to send us a bunch of data.
-    async def on_PUT(self, origin, content, query, transaction_id):
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        transaction_id: str,
+    ) -> Tuple[int, JsonDict]:
         """Called on PUT /send/<transaction_id>/
 
         Args:
-            request (twisted.web.http.Request): The HTTP request.
-            transaction_id (str): The transaction_id associated with this
-                request. This is *not* None.
+            transaction_id: The transaction_id associated with this request. This
+                is *not* None.
 
         Returns:
             Tuple of `(code, response)`, where
@@ -461,7 +474,13 @@ class FederationEventServlet(BaseFederationServerServlet):
     PATH = "/event/(?P<event_id>[^/]*)/?"
 
     # This is when someone asks for a data item for a given server data_id pair.
-    async def on_GET(self, origin, content, query, event_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        event_id: str,
+    ) -> Tuple[int, Union[JsonDict, str]]:
         return await self.handler.on_pdu_request(origin, event_id)
 
 
@@ -469,7 +488,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
     PATH = "/state/(?P<room_id>[^/]*)/?"
 
     # This is when someone asks for all data for a given room.
-    async def on_GET(self, origin, content, query, room_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         return await self.handler.on_room_state_request(
             origin,
             room_id,
@@ -480,7 +505,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
 class FederationStateIdsServlet(BaseFederationServerServlet):
     PATH = "/state_ids/(?P<room_id>[^/]*)/?"
 
-    async def on_GET(self, origin, content, query, room_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         return await self.handler.on_state_ids_request(
             origin,
             room_id,
@@ -491,7 +522,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
 class FederationBackfillServlet(BaseFederationServerServlet):
     PATH = "/backfill/(?P<room_id>[^/]*)/?"
 
-    async def on_GET(self, origin, content, query, room_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         versions = [x.decode("ascii") for x in query[b"v"]]
         limit = parse_integer_from_args(query, "limit", None)
 
@@ -505,7 +542,13 @@ class FederationQueryServlet(BaseFederationServerServlet):
     PATH = "/query/(?P<query_type>[^/]*)"
 
     # This is when we receive a server-server Query
-    async def on_GET(self, origin, content, query, query_type):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        query_type: str,
+    ) -> Tuple[int, JsonDict]:
         args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
         args["origin"] = origin
         return await self.handler.on_query_request(query_type, args)
@@ -514,47 +557,66 @@ class FederationQueryServlet(BaseFederationServerServlet):
 class FederationMakeJoinServlet(BaseFederationServerServlet):
     PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, _content, query, room_id, user_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         """
         Args:
-            origin (unicode): The authenticated server_name of the calling server
+            origin: The authenticated server_name of the calling server
 
-            _content (None): (GETs don't have bodies)
+            content: (GETs don't have bodies)
 
-            query (dict[bytes, list[bytes]]): Query params from the request.
+            query: Query params from the request.
 
-            **kwargs (dict[unicode, unicode]): the dict mapping keys to path
-                components as specified in the path match regexp.
+            **kwargs: the dict mapping keys to path components as specified in
+                the path match regexp.
 
         Returns:
-            Tuple[int, object]: (response code, response object)
+            Tuple of (response code, response object)
         """
-        versions = query.get(b"ver")
-        if versions is not None:
-            supported_versions = [v.decode("utf-8") for v in versions]
-        else:
+        supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+        if supported_versions is None:
             supported_versions = ["1"]
 
-        content = await self.handler.on_make_join_request(
+        result = await self.handler.on_make_join_request(
             origin, room_id, user_id, supported_versions=supported_versions
         )
-        return 200, content
+        return 200, result
 
 
 class FederationMakeLeaveServlet(BaseFederationServerServlet):
     PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, room_id, user_id):
-        content = await self.handler.on_make_leave_request(origin, room_id, user_id)
-        return 200, content
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_make_leave_request(origin, room_id, user_id)
+        return 200, result
 
 
 class FederationV1SendLeaveServlet(BaseFederationServerServlet):
     PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
-        content = await self.handler.on_send_leave_request(origin, content)
-        return 200, (200, content)
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, result)
 
 
 class FederationV2SendLeaveServlet(BaseFederationServerServlet):
@@ -562,50 +624,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
 
     PREFIX = FEDERATION_V2_PREFIX
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
-        content = await self.handler.on_send_leave_request(origin, content)
-        return 200, content
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, result
 
 
 class FederationMakeKnockServlet(BaseFederationServerServlet):
     PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, room_id, user_id):
-        try:
-            # Retrieve the room versions the remote homeserver claims to support
-            supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
-        except KeyError:
-            raise SynapseError(400, "Missing required query parameter 'ver'")
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # Retrieve the room versions the remote homeserver claims to support
+        supported_versions = parse_strings_from_args(
+            query, "ver", required=True, encoding="utf-8"
+        )
 
-        content = await self.handler.on_make_knock_request(
+        result = await self.handler.on_make_knock_request(
             origin, room_id, user_id, supported_versions=supported_versions
         )
-        return 200, content
+        return 200, result
 
 
 class FederationV1SendKnockServlet(BaseFederationServerServlet):
     PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
-        content = await self.handler.on_send_knock_request(origin, content, room_id)
-        return 200, content
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_knock_request(origin, content, room_id)
+        return 200, result
 
 
 class FederationEventAuthServlet(BaseFederationServerServlet):
     PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, room_id, event_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
         return await self.handler.on_event_auth(origin, room_id, event_id)
 
 
 class FederationV1SendJoinServlet(BaseFederationServerServlet):
     PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
-        # TODO(paul): assert that room_id/event_id parsed from path actually
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        # TODO(paul): assert that event_id parsed from path actually
         #   match those given in content
-        content = await self.handler.on_send_join_request(origin, content)
-        return 200, (200, content)
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, (200, result)
 
 
 class FederationV2SendJoinServlet(BaseFederationServerServlet):
@@ -613,28 +709,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
 
     PREFIX = FEDERATION_V2_PREFIX
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
-        # TODO(paul): assert that room_id/event_id parsed from path actually
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # TODO(paul): assert that event_id parsed from path actually
         #   match those given in content
-        content = await self.handler.on_send_join_request(origin, content)
-        return 200, content
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, result
 
 
 class FederationV1InviteServlet(BaseFederationServerServlet):
     PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
         # We don't get a room version, so we have to assume its EITHER v1 or
         # v2. This is "fine" as the only difference between V1 and V2 is the
         # state resolution algorithm, and we don't use that for processing
         # invites
-        content = await self.handler.on_invite_request(
+        result = await self.handler.on_invite_request(
             origin, content, room_version_id=RoomVersions.V1.identifier
         )
 
         # V1 federation API is defined to return a content of `[200, {...}]`
         # due to a historical bug.
-        return 200, (200, content)
+        return 200, (200, result)
 
 
 class FederationV2InviteServlet(BaseFederationServerServlet):
@@ -642,7 +752,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
 
     PREFIX = FEDERATION_V2_PREFIX
 
-    async def on_PUT(self, origin, content, query, room_id, event_id):
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
         # TODO(paul): assert that room_id/event_id parsed from path actually
         #   match those given in content
 
@@ -655,16 +772,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
 
         event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
 
-        content = await self.handler.on_invite_request(
+        result = await self.handler.on_invite_request(
             origin, event, room_version_id=room_version
         )
-        return 200, content
+        return 200, result
 
 
 class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
     PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
 
-    async def on_PUT(self, origin, content, query, room_id):
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         await self.handler.on_exchange_third_party_invite_request(content)
         return 200, {}
 
@@ -672,21 +795,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
 class FederationClientKeysQueryServlet(BaseFederationServerServlet):
     PATH = "/user/keys/query"
 
-    async def on_POST(self, origin, content, query):
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         return await self.handler.on_query_client_keys(origin, content)
 
 
 class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
     PATH = "/user/devices/(?P<user_id>[^/]*)"
 
-    async def on_GET(self, origin, content, query, user_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         return await self.handler.on_query_user_devices(origin, user_id)
 
 
 class FederationClientKeysClaimServlet(BaseFederationServerServlet):
     PATH = "/user/keys/claim"
 
-    async def on_POST(self, origin, content, query):
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         response = await self.handler.on_claim_client_keys(origin, content)
         return 200, response
 
@@ -695,12 +828,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
     # TODO(paul): Why does this path alone end with "/?" optional?
     PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
 
-    async def on_POST(self, origin, content, query, room_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         limit = int(content.get("limit", 10))
         earliest_events = content.get("earliest_events", [])
         latest_events = content.get("latest_events", [])
 
-        content = await self.handler.on_get_missing_events(
+        result = await self.handler.on_get_missing_events(
             origin,
             room_id=room_id,
             earliest_events=earliest_events,
@@ -708,7 +847,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
             limit=limit,
         )
 
-        return 200, content
+        return 200, result
 
 
 class On3pidBindServlet(BaseFederationServerServlet):
@@ -716,7 +855,9 @@ class On3pidBindServlet(BaseFederationServerServlet):
 
     REQUIRE_AUTH = False
 
-    async def on_POST(self, origin, content, query):
+    async def on_POST(
+        self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -762,15 +903,20 @@ class OpenIdUserInfo(BaseFederationServerServlet):
 
     REQUIRE_AUTH = False
 
-    async def on_GET(self, origin, content, query):
-        token = query.get(b"access_token", [None])[0]
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
+        token = parse_string_from_args(query, "access_token")
         if token is None:
             return (
                 401,
                 {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
             )
 
-        user_id = await self.handler.on_openid_userinfo(token.decode("ascii"))
+        user_id = await self.handler.on_openid_userinfo(token)
 
         if user_id is None:
             return (
@@ -829,7 +975,9 @@ class PublicRoomList(BaseFederationServlet):
         self.handler = hs.get_room_list_handler()
         self.allow_access = allow_access
 
-    async def on_GET(self, origin, content, query):
+    async def on_GET(
+        self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         if not self.allow_access:
             raise FederationDeniedError(origin)
 
@@ -858,7 +1006,9 @@ class PublicRoomList(BaseFederationServlet):
         )
         return 200, data
 
-    async def on_POST(self, origin, content, query):
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         # This implements MSC2197 (Search Filtering over Federation)
         if not self.allow_access:
             raise FederationDeniedError(origin)
@@ -904,7 +1054,12 @@ class FederationVersionServlet(BaseFederationServlet):
 
     REQUIRE_AUTH = False
 
-    async def on_GET(self, origin, content, query):
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
         return (
             200,
             {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
@@ -933,7 +1088,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/profile"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -942,7 +1103,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
 
         return 200, new_content
 
-    async def on_POST(self, origin, content, query, group_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -957,7 +1124,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet):
 class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
     PATH = "/groups/(?P<group_id>[^/]*)/summary"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -972,7 +1145,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/rooms"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -987,7 +1166,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
 
-    async def on_POST(self, origin, content, query, group_id, room_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -998,7 +1184,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
 
         return 200, new_content
 
-    async def on_DELETE(self, origin, content, query, group_id, room_id):
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1018,7 +1211,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
         "/config/(?P<config_key>[^/]*)"
     )
 
-    async def on_POST(self, origin, content, query, group_id, room_id, config_key):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+        config_key: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1035,7 +1236,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/users"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1050,7 +1257,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1067,7 +1280,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1084,7 +1304,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         if get_domain_from_id(user_id) != origin:
             raise SynapseError(403, "user_id doesn't match origin")
 
@@ -1098,7 +1325,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         if get_domain_from_id(user_id) != origin:
             raise SynapseError(403, "user_id doesn't match origin")
 
@@ -1112,7 +1346,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1146,7 +1387,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
 
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         if get_domain_from_id(group_id) != origin:
             raise SynapseError(403, "group_id doesn't match origin")
 
@@ -1164,7 +1412,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
 
     PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, None]:
         if get_domain_from_id(group_id) != origin:
             raise SynapseError(403, "user_id doesn't match origin")
 
@@ -1172,11 +1427,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
             self.handler, GroupsLocalHandler
         ), "Workers cannot handle group removals."
 
-        new_content = await self.handler.user_removed_from_group(
-            group_id, user_id, content
-        )
+        await self.handler.user_removed_from_group(group_id, user_id, content)
 
-        return 200, new_content
+        return 200, None
 
 
 class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
@@ -1194,7 +1447,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
         super().__init__(hs, authenticator, ratelimiter, server_name)
         self.handler = hs.get_groups_attestation_renewer()
 
-    async def on_POST(self, origin, content, query, group_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         # We don't need to check auth here as we check the attestation signatures
 
         new_content = await self.handler.on_renew_attestation(
@@ -1218,7 +1478,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
         "/rooms/(?P<room_id>[^/]*)"
     )
 
-    async def on_POST(self, origin, content, query, group_id, category_id, room_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1246,7 +1514,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1266,7 +1542,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1281,7 +1563,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
 
-    async def on_GET(self, origin, content, query, group_id, category_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1292,7 +1581,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_POST(self, origin, content, query, group_id, category_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1314,7 +1610,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_DELETE(self, origin, content, query, group_id, category_id):
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1334,7 +1637,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
 
-    async def on_GET(self, origin, content, query, group_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1349,7 +1658,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
 
-    async def on_GET(self, origin, content, query, group_id, role_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        role_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1358,7 +1674,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_POST(self, origin, content, query, group_id, role_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        role_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1382,7 +1705,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_DELETE(self, origin, content, query, group_id, role_id):
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        role_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1411,7 +1741,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
         "/users/(?P<user_id>[^/]*)"
     )
 
-    async def on_POST(self, origin, content, query, group_id, role_id, user_id):
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        role_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1437,7 +1775,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
 
         return 200, resp
 
-    async def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        role_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1457,7 +1803,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
 
     PATH = "/get_groups_publicised"
 
-    async def on_POST(self, origin, content, query):
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
         resp = await self.handler.bulk_get_publicised_groups(
             content["user_ids"], proxy=False
         )
@@ -1470,7 +1818,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
 
     PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
 
-    async def on_PUT(self, origin, content, query, group_id):
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
         requester_user_id = parse_string_from_args(query, "requester_user_id")
         if get_domain_from_id(requester_user_id) != origin:
             raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1499,7 +1853,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
     async def on_GET(
         self,
         origin: str,
-        content: JsonDict,
+        content: Literal[None],
         query: Mapping[bytes, Sequence[bytes]],
         room_id: str,
     ) -> Tuple[int, JsonDict]:
@@ -1571,7 +1925,13 @@ class RoomComplexityServlet(BaseFederationServlet):
         super().__init__(hs, authenticator, ratelimiter, server_name)
         self._store = self.hs.get_datastore()
 
-    async def on_GET(self, origin, content, query, room_id):
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
         is_public = await self._store.is_room_world_readable_or_publicly_joinable(
             room_id
         )
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 1971e373ed..e2ac595a62 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -30,6 +30,7 @@ from typing import (
     Optional,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
 
 if TYPE_CHECKING:
+    from synapse.rest.client.v1.login import LoginResponse
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -777,6 +779,108 @@ class AuthHandler(BaseHandler):
             "params": params,
         }
 
+    async def refresh_token(
+        self,
+        refresh_token: str,
+        valid_until_ms: Optional[int],
+    ) -> Tuple[str, str]:
+        """
+        Consumes a refresh token and generate both a new access token and a new refresh token from it.
+
+        The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
+
+        Args:
+            refresh_token: The token to consume.
+            valid_until_ms: The expiration timestamp of the new access token.
+
+        Returns:
+            A tuple containing the new access token and refresh token
+        """
+
+        # Verify the token signature first before looking up the token
+        if not self._verify_refresh_token(refresh_token):
+            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+
+        existing_token = await self.store.lookup_refresh_token(refresh_token)
+        if existing_token is None:
+            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+
+        if (
+            existing_token.has_next_access_token_been_used
+            or existing_token.has_next_refresh_token_been_refreshed
+        ):
+            raise SynapseError(
+                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+            )
+
+        (
+            new_refresh_token,
+            new_refresh_token_id,
+        ) = await self.get_refresh_token_for_user_id(
+            user_id=existing_token.user_id, device_id=existing_token.device_id
+        )
+        access_token = await self.get_access_token_for_user_id(
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            valid_until_ms=valid_until_ms,
+            refresh_token_id=new_refresh_token_id,
+        )
+        await self.store.replace_refresh_token(
+            existing_token.token_id, new_refresh_token_id
+        )
+        return access_token, new_refresh_token
+
+    def _verify_refresh_token(self, token: str) -> bool:
+        """
+        Verifies the shape of a refresh token.
+
+        Args:
+            token: The refresh token to verify
+
+        Returns:
+            Whether the token has the right shape
+        """
+        parts = token.split("_", maxsplit=4)
+        if len(parts) != 4:
+            return False
+
+        type, localpart, rand, crc = parts
+
+        # Refresh tokens are prefixed by "syr_", let's check that
+        if type != "syr":
+            return False
+
+        # Check the CRC
+        base = f"{type}_{localpart}_{rand}"
+        expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        if crc != expected_crc:
+            return False
+
+        return True
+
+    async def get_refresh_token_for_user_id(
+        self,
+        user_id: str,
+        device_id: str,
+    ) -> Tuple[str, int]:
+        """
+        Creates a new refresh token for the user with the given user ID.
+
+        Args:
+            user_id: canonical user ID
+            device_id: the device ID to associate with the token.
+
+        Returns:
+            The newly created refresh token and its ID in the database
+        """
+        refresh_token = self.generate_refresh_token(UserID.from_string(user_id))
+        refresh_token_id = await self.store.add_refresh_token_to_user(
+            user_id=user_id,
+            token=refresh_token,
+            device_id=device_id,
+        )
+        return refresh_token, refresh_token_id
+
     async def get_access_token_for_user_id(
         self,
         user_id: str,
@@ -784,6 +888,7 @@ class AuthHandler(BaseHandler):
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
         is_appservice_ghost: bool = False,
+        refresh_token_id: Optional[int] = None,
     ) -> str:
         """
         Creates a new access token for the user with the given user ID.
@@ -801,6 +906,8 @@ class AuthHandler(BaseHandler):
             valid_until_ms: when the token is valid until. None for
                 no expiry.
             is_appservice_ghost: Whether the user is an application ghost user
+            refresh_token_id: the refresh token ID that will be associated with
+                this access token.
         Returns:
               The access token for the user's session.
         Raises:
@@ -836,6 +943,7 @@ class AuthHandler(BaseHandler):
             device_id=device_id,
             valid_until_ms=valid_until_ms,
             puppets_user_id=puppets_user_id,
+            refresh_token_id=refresh_token_id,
         )
 
         # the device *should* have been registered before we got here; however,
@@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler):
         self,
         login_submission: Dict[str, Any],
         ratelimit: bool = False,
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate auth types which don't
@@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler):
         self,
         username: str,
         login_submission: Dict[str, Any],
-    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Helper for validate_login
 
         Handles login, once we've mapped 3pids onto userids
@@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler):
 
     async def check_password_provider_3pid(
         self, medium: str, address: str, password: str
-    ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
+    ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
         """Check if a password provider is able to validate a thirdparty login
 
         Args:
@@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler):
         crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
         return f"{base}_{crc}"
 
+    def generate_refresh_token(self, for_user: UserID) -> str:
+        """Generates an opaque string, for use as a refresh token"""
+
+        # we use the following format for refresh tokens:
+        #    syr_<base64 local part>_<random string>_<base62 crc check>
+
+        b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8"))
+        random_string = stringutils.random_string(20)
+        base = f"syr_{b64local}_{random_string}"
+
+        crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        return f"{base}_{crc}"
+
     async def validate_short_term_login_token(
         self, login_token: str
     ) -> LoginTokenAttributes:
@@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler):
         )
         respond_with_html(request, 200, html)
 
-    async def _sso_login_callback(self, login_result: JsonDict) -> None:
+    async def _sso_login_callback(self, login_result: "LoginResponse") -> None:
         """
         A login callback which might add additional attributes to the login response.
 
@@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler):
 
         extra_attributes = self._extra_attributes.get(login_result["user_id"])
         if extra_attributes:
-            login_result.update(extra_attributes.extra_attributes)
+            login_result_dict = cast(Dict[str, Any], login_result)
+            login_result_dict.update(extra_attributes.extra_attributes)
 
     def _expire_sso_extra_attributes(self) -> None:
         """
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1b566dbf2d..d929c65131 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1711,80 +1711,6 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
-        """We have received a join event for a room. Fully process it and
-        respond with the current state and auth chains.
-        """
-        event = pdu
-
-        logger.debug(
-            "on_send_join_request from %s: Got event: %s, signatures: %s",
-            origin,
-            event.event_id,
-            event.signatures,
-        )
-
-        if get_domain_from_id(event.sender) != origin:
-            logger.info(
-                "Got /send_join request for user %r from different origin %s",
-                event.sender,
-                origin,
-            )
-            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
-        event.internal_metadata.outlier = False
-        # Send this event on behalf of the origin server.
-        #
-        # The reasons we have the destination server rather than the origin
-        # server send it are slightly mysterious: the origin server should have
-        # all the necessary state once it gets the response to the send_join,
-        # so it could send the event itself if it wanted to. It may be that
-        # doing it this way reduces failure modes, or avoids certain attacks
-        # where a new server selectively tells a subset of the federation that
-        # it has joined.
-        #
-        # The fact is that, as of the current writing, Synapse doesn't send out
-        # the join event over federation after joining, and changing it now
-        # would introduce the danger of backwards-compatibility problems.
-        event.internal_metadata.send_on_behalf_of = origin
-
-        # Calculate the event context.
-        context = await self.state_handler.compute_event_context(event)
-
-        # Get the state before the new event.
-        prev_state_ids = await context.get_prev_state_ids()
-
-        # Check if the user is already in the room or invited to the room.
-        user_id = event.state_key
-        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
-        if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
-
-        # Check if the member should be allowed access via membership in a space.
-        await self._event_auth_handler.check_restricted_join_rules(
-            prev_state_ids,
-            event.room_version,
-            user_id,
-            prev_member_event,
-        )
-
-        # Persist the event.
-        await self._auth_and_persist_event(origin, event, context)
-
-        logger.debug(
-            "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
-            event.event_id,
-            event.signatures,
-        )
-
-        state_ids = list(prev_state_ids.values())
-        auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
-
-        state = await self.store.get_events(list(prev_state_ids.values()))
-
-        return {"state": list(state.values()), "auth_chain": auth_chain}
-
     async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
     ) -> EventBase:
@@ -1960,37 +1886,6 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
-        """We have received a leave event for a room. Fully process it."""
-        event = pdu
-
-        logger.debug(
-            "on_send_leave_request: Got event: %s, signatures: %s",
-            event.event_id,
-            event.signatures,
-        )
-
-        if get_domain_from_id(event.sender) != origin:
-            logger.info(
-                "Got /send_leave request for user %r from different origin %s",
-                event.sender,
-                origin,
-            )
-            raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
-
-        event.internal_metadata.outlier = False
-
-        context = await self.state_handler.compute_event_context(event)
-        await self._auth_and_persist_event(origin, event, context)
-
-        logger.debug(
-            "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
-            event.event_id,
-            event.signatures,
-        )
-
-        return None
-
     @log_function
     async def on_make_knock_request(
         self, origin: str, room_id: str, user_id: str
@@ -2054,51 +1949,115 @@ class FederationHandler(BaseHandler):
         return event
 
     @log_function
-    async def on_send_knock_request(
+    async def on_send_membership_event(
         self, origin: str, event: EventBase
     ) -> EventContext:
         """
-        We have received a knock event for a room. Verify that event and send it into the room
-        on the knocking homeserver's behalf.
+        We have received a join/leave/knock event for a room via send_join/leave/knock.
+
+        Verify that event and send it into the room on the remote homeserver's behalf.
+
+        This is quite similar to on_receive_pdu, with the following principal
+        differences:
+          * only membership events are permitted (and only events with
+            sender==state_key -- ie, no kicks or bans)
+          * *We* send out the event on behalf of the remote server.
+          * We enforce the membership restrictions of restricted rooms.
+          * Rejected events result in an exception rather than being stored.
+
+        There are also other differences, however it is not clear if these are by
+        design or omission. In particular, we do not attempt to backfill any missing
+        prev_events.
 
         Args:
-            origin: The remote homeserver of the knocking user.
-            event: The knocking member event that has been signed by the remote homeserver.
+            origin: The homeserver of the remote (joining/invited/knocking) user.
+            event: The member event that has been signed by the remote homeserver.
 
         Returns:
             The context of the event after inserting it into the room graph.
+
+        Raises:
+            SynapseError if the event is not accepted into the room
         """
         logger.debug(
-            "on_send_knock_request: Got event: %s, signatures: %s",
+            "on_send_membership_event: Got event: %s, signatures: %s",
             event.event_id,
             event.signatures,
         )
 
         if get_domain_from_id(event.sender) != origin:
             logger.info(
-                "Got /send_knock request for user %r from different origin %s",
+                "Got send_membership request for user %r from different origin %s",
                 event.sender,
                 origin,
             )
             raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
 
-        event.internal_metadata.outlier = False
+        if event.sender != event.state_key:
+            raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON)
 
-        context = await self.state_handler.compute_event_context(event)
+        assert not event.internal_metadata.outlier
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Sending of knock %s forbidden by third-party rules", event)
+        # Send this event on behalf of the other server.
+        #
+        # The remote server isn't a full participant in the room at this point, so
+        # may not have an up-to-date list of the other homeservers participating in
+        # the room, so we send it on their behalf.
+        event.internal_metadata.send_on_behalf_of = origin
+
+        context = await self.state_handler.compute_event_context(event)
+        context = await self._check_event_auth(origin, event, context)
+        if context.rejected:
             raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
+                403, f"{event.membership} event was rejected", Codes.FORBIDDEN
             )
 
-        await self._auth_and_persist_event(origin, event, context)
+        # for joins, we need to check the restrictions of restricted rooms
+        if event.membership == Membership.JOIN:
+            await self._check_join_restrictions(context, event)
 
+        # for knock events, we run the third-party event rules. It's not entirely clear
+        # why we don't do this for other sorts of membership events.
+        if event.membership == Membership.KNOCK:
+            event_allowed = await self.third_party_event_rules.check_event_allowed(
+                event, context
+            )
+            if not event_allowed:
+                logger.info("Sending of knock %s forbidden by third-party rules", event)
+                raise SynapseError(
+                    403, "This event is not allowed in this context", Codes.FORBIDDEN
+                )
+
+        # all looks good, we can persist the event.
+        await self._run_push_actions_and_persist_event(event, context)
         return context
 
+    async def _check_join_restrictions(
+        self, context: EventContext, event: EventBase
+    ) -> None:
+        """Check that restrictions in restricted join rules are matched
+
+        Called when we receive a join event via send_join.
+
+        Raises an auth error if the restrictions are not matched.
+        """
+        prev_state_ids = await context.get_prev_state_ids()
+
+        # Check if the user is already in the room or invited to the room.
+        user_id = event.state_key
+        prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+
+        # Check if the member should be allowed access via membership in a space.
+        await self._event_auth_handler.check_restricted_join_rules(
+            prev_state_ids,
+            event.room_version,
+            user_id,
+            prev_member_event,
+        )
+
     async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
         """Returns the state at the event. i.e. not including said event."""
 
@@ -2240,6 +2199,18 @@ class FederationHandler(BaseHandler):
             backfilled=backfilled,
         )
 
+        await self._run_push_actions_and_persist_event(event, context, backfilled)
+
+    async def _run_push_actions_and_persist_event(
+        self, event: EventBase, context: EventContext, backfilled: bool = False
+    ):
+        """Run the push actions for a received event, and persist it.
+
+        Args:
+            event: The event itself.
+            context: The event context.
+            backfilled: True if the event was backfilled.
+        """
         try:
             if (
                 not event.internal_metadata.is_outlier()
@@ -2553,9 +2524,9 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         context: EventContext,
-        state: Optional[Iterable[EventBase]],
-        auth_events: Optional[MutableStateMap[EventBase]],
-        backfilled: bool,
+        state: Optional[Iterable[EventBase]] = None,
+        auth_events: Optional[MutableStateMap[EventBase]] = None,
+        backfilled: bool = False,
     ) -> EventContext:
         """
         Checks whether an event should be rejected (for failing auth checks).
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1ed6a5c0..26ef016179 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,9 +15,10 @@
 """Contains functions for registering clients."""
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
+from typing_extensions import TypedDict
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -54,6 +55,16 @@ login_counter = Counter(
     ["guest", "auth_provider"],
 )
 
+LoginDict = TypedDict(
+    "LoginDict",
+    {
+        "device_id": str,
+        "access_token": str,
+        "valid_until_ms": Optional[int],
+        "refresh_token": Optional[str],
+    },
+)
+
 
 class RegistrationHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
@@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler):
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.session_lifetime
+        self.access_token_lifetime = hs.config.access_token_lifetime
 
     async def check_username(
         self,
@@ -386,11 +398,32 @@ class RegistrationHandler(BaseHandler):
                 room_alias = RoomAlias.from_string(r)
 
                 if self.hs.hostname != room_alias.domain:
-                    logger.warning(
-                        "Cannot create room alias %s, "
-                        "it does not match server domain",
+                    # If the alias is remote, try to join the room. This might fail
+                    # because the room might be invite only, but we don't have any local
+                    # user in the room to invite this one with, so at this point that's
+                    # the best we can do.
+                    logger.info(
+                        "Cannot automatically create room with alias %s as it isn't"
+                        " local, trying to join the room instead",
                         r,
                     )
+
+                    (
+                        room,
+                        remote_room_hosts,
+                    ) = await room_member_handler.lookup_room_alias(room_alias)
+                    room_id = room.to_string()
+
+                    await room_member_handler.update_membership(
+                        requester=create_requester(
+                            user_id, authenticated_entity=self._server_name
+                        ),
+                        target=UserID.from_string(user_id),
+                        room_id=room_id,
+                        remote_room_hosts=remote_room_hosts,
+                        action="join",
+                        ratelimit=False,
+                    )
                 else:
                     # A shallow copy is OK here since the only key that is
                     # modified is room_alias_name.
@@ -448,22 +481,32 @@ class RegistrationHandler(BaseHandler):
                     )
 
                 # Calculate whether the room requires an invite or can be
-                # joined directly. Note that unless a join rule of public exists,
-                # it is treated as requiring an invite.
-                requires_invite = True
-
-                state = await self.store.get_filtered_current_state_ids(
-                    room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
+                # joined directly. By default, we consider the room as requiring an
+                # invite if the homeserver is in the room (unless told otherwise by the
+                # join rules). Otherwise we consider it as being joinable, at the risk of
+                # failing to join, but in this case there's little more we can do since
+                # we don't have a local user in the room to craft up an invite with.
+                requires_invite = await self.store.is_host_joined(
+                    room_id,
+                    self.server_name,
                 )
 
-                event_id = state.get((EventTypes.JoinRules, ""))
-                if event_id:
-                    join_rules_event = await self.store.get_event(
-                        event_id, allow_none=True
+                if requires_invite:
+                    # If the server is in the room, check if the room is public.
+                    state = await self.store.get_filtered_current_state_ids(
+                        room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
                     )
-                    if join_rules_event:
-                        join_rule = join_rules_event.content.get("join_rule", None)
-                        requires_invite = join_rule and join_rule != JoinRules.PUBLIC
+
+                    event_id = state.get((EventTypes.JoinRules, ""))
+                    if event_id:
+                        join_rules_event = await self.store.get_event(
+                            event_id, allow_none=True
+                        )
+                        if join_rules_event:
+                            join_rule = join_rules_event.content.get("join_rule", None)
+                            requires_invite = (
+                                join_rule and join_rule != JoinRules.PUBLIC
+                            )
 
                 # Send the invite, if necessary.
                 if requires_invite:
@@ -665,7 +708,8 @@ class RegistrationHandler(BaseHandler):
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
-    ) -> Tuple[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
@@ -677,8 +721,9 @@ class RegistrationHandler(BaseHandler):
             is_guest: Whether this is a guest account
             auth_provider_id: The SSO IdP the user used, if any (just used for the
                 prometheus metrics).
+            should_issue_refresh_token: Whether it should also issue a refresh token
         Returns:
-            Tuple of device ID and access token
+            Tuple of device ID, access token, access token expiration time and refresh token
         """
         res = await self._register_device_client(
             user_id=user_id,
@@ -686,6 +731,7 @@ class RegistrationHandler(BaseHandler):
             initial_display_name=initial_display_name,
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
         login_counter.labels(
@@ -693,7 +739,12 @@ class RegistrationHandler(BaseHandler):
             auth_provider=(auth_provider_id or ""),
         ).inc()
 
-        return res["device_id"], res["access_token"]
+        return (
+            res["device_id"],
+            res["access_token"],
+            res["valid_until_ms"],
+            res["refresh_token"],
+        )
 
     async def register_device_inner(
         self,
@@ -702,7 +753,8 @@ class RegistrationHandler(BaseHandler):
         initial_display_name: Optional[str],
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
-    ) -> Dict[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> LoginDict:
         """Helper for register_device
 
         Does the bits that need doing on the main process. Not for use outside this
@@ -717,6 +769,9 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
+        refresh_token = None
+        refresh_token_id = None
+
         registered_device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
@@ -724,14 +779,30 @@ class RegistrationHandler(BaseHandler):
             assert valid_until_ms is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
+            if should_issue_refresh_token:
+                (
+                    refresh_token,
+                    refresh_token_id,
+                ) = await self._auth_handler.get_refresh_token_for_user_id(
+                    user_id,
+                    device_id=registered_device_id,
+                )
+                valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
+
             access_token = await self._auth_handler.get_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
                 valid_until_ms=valid_until_ms,
                 is_appservice_ghost=is_appservice_ghost,
+                refresh_token_id=refresh_token_id,
             )
 
-        return {"device_id": registered_device_id, "access_token": access_token}
+        return {
+            "device_id": registered_device_id,
+            "access_token": access_token,
+            "valid_until_ms": valid_until_ms,
+            "refresh_token": refresh_token,
+        }
 
     async def post_registration_actions(
         self, user_id: str, auth_result: dict, access_token: Optional[str]
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 845651e606..efbc6d5b25 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -728,7 +728,7 @@ def set_cors_headers(request: Request):
     )
     request.setHeader(
         b"Access-Control-Allow-Headers",
-        b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
+        b"X-Requested-With, Content-Type, Authorization, Date",
     )
 
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fda8da21b7..6ba2ce1e53 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -113,8 +113,18 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 def parse_bytes_from_args(
     args: Dict[bytes, List[bytes]],
     name: str,
+    default: Optional[bytes] = None,
+) -> Optional[bytes]:
+    ...
+
+
+@overload
+def parse_bytes_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
     default: Literal[None] = None,
-    required: Literal[True] = True,
+    *,
+    required: Literal[True],
 ) -> bytes:
     ...
 
@@ -197,7 +207,12 @@ def parse_string(
     """
     args = request.args  # type: Dict[bytes, List[bytes]]  # type: ignore
     return parse_string_from_args(
-        args, name, default, required, allowed_values, encoding
+        args,
+        name,
+        default,
+        required=required,
+        allowed_values=allowed_values,
+        encoding=encoding,
     )
 
 
@@ -227,7 +242,20 @@ def parse_strings_from_args(
     args: Dict[bytes, List[bytes]],
     name: str,
     default: Optional[List[str]] = None,
-    required: Literal[True] = True,
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[List[str]]:
+    ...
+
+
+@overload
+def parse_strings_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[List[str]] = None,
+    *,
+    required: Literal[True],
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
 ) -> List[str]:
@@ -239,6 +267,7 @@ def parse_strings_from_args(
     args: Dict[bytes, List[bytes]],
     name: str,
     default: Optional[List[str]] = None,
+    *,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -299,7 +328,20 @@ def parse_string_from_args(
     args: Dict[bytes, List[bytes]],
     name: str,
     default: Optional[str] = None,
-    required: Literal[True] = True,
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[str]:
+    ...
+
+
+@overload
+def parse_string_from_args(
+    args: Dict[bytes, List[bytes]],
+    name: str,
+    default: Optional[str] = None,
+    *,
+    required: Literal[True],
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
 ) -> str:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 58b255eb1b..721c45abac 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -168,7 +168,7 @@ class ModuleApi:
             "Using deprecated ModuleApi.register which creates a dummy user device."
         )
         user_id = yield self.register_user(localpart, displayname, emails or [])
-        _, access_token = yield self.register_device(user_id)
+        _, access_token, _, _ = yield self.register_device(user_id)
         return user_id, access_token
 
     def register_user(
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index c2e8c00293..550bd5c95f 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
 
     @staticmethod
     async def _serialize_payload(
-        user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+        user_id,
+        device_id,
+        initial_display_name,
+        is_guest,
+        is_appservice_ghost,
+        should_issue_refresh_token,
     ):
         """
         Args:
+            user_id (int)
             device_id (str|None): Device ID to use, if None a new one is
                 generated.
             initial_display_name (str|None)
             is_guest (bool)
+            is_appservice_ghost (bool)
+            should_issue_refresh_token (bool)
         """
         return {
             "device_id": device_id,
             "initial_display_name": initial_display_name,
             "is_guest": is_guest,
             "is_appservice_ghost": is_appservice_ghost,
+            "should_issue_refresh_token": should_issue_refresh_token,
         }
 
     async def _handle_request(self, request, user_id):
@@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
         is_appservice_ghost = content["is_appservice_ghost"]
+        should_issue_refresh_token = content["should_issue_refresh_token"]
 
         res = await self.registration_handler.register_device_inner(
             user_id,
@@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             initial_display_name,
             is_guest,
             is_appservice_ghost=is_appservice_ghost,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
         return 200, res
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index f6be5f1020..cbcb60fe31 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,9 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -25,6 +27,8 @@ from synapse.http import get_request_uri
 from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
+    assert_params_in_dict,
+    parse_boolean,
     parse_bytes_from_args,
     parse_json_object_from_request,
     parse_string,
@@ -40,6 +44,21 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+LoginResponse = TypedDict(
+    "LoginResponse",
+    {
+        "user_id": str,
+        "access_token": str,
+        "home_server": str,
+        "expires_in_ms": Optional[int],
+        "refresh_token": Optional[str],
+        "device_id": str,
+        "well_known": Optional[Dict[str, Any]],
+    },
+    total=False,
+)
+
+
 class LoginRestServlet(RestServlet):
     PATTERNS = client_patterns("/login$", v1=True)
     CAS_TYPE = "m.login.cas"
@@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE = "org.matrix.login.jwt"
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
+    REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token"
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
@@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet):
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+        self._msc2918_enabled = hs.config.access_token_lifetime is not None
 
         self.auth = hs.get_auth()
 
+        self.clock = hs.get_clock()
+
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self._sso_handler = hs.get_sso_handler()
@@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest):
         login_submission = parse_json_object_from_request(request)
 
+        if self._msc2918_enabled:
+            # Check if this login should also issue a refresh token, as per
+            # MSC2918
+            should_issue_refresh_token = parse_boolean(
+                request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False
+            )
+        else:
+            should_issue_refresh_token = False
+
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
@@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet):
                         None, request.getClientIP()
                     )
 
-                result = await self._do_appservice_login(login_submission, appservice)
+                result = await self._do_appservice_login(
+                    login_submission,
+                    appservice,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_jwt_login(login_submission)
+                result = await self._do_jwt_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_token_login(login_submission)
+                result = await self._do_token_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
             else:
                 await self._address_ratelimiter.ratelimit(None, request.getClientIP())
-                result = await self._do_other_login(login_submission)
+                result = await self._do_other_login(
+                    login_submission,
+                    should_issue_refresh_token=should_issue_refresh_token,
+                )
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
 
@@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet):
         return 200, result
 
     async def _do_appservice_login(
-        self, login_submission: JsonDict, appservice: ApplicationService
+        self,
+        login_submission: JsonDict,
+        appservice: ApplicationService,
+        should_issue_refresh_token: bool = False,
     ):
         identifier = login_submission.get("identifier")
         logger.info("Got appservice login request with identifier: %r", identifier)
@@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
         return await self._complete_login(
-            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+            qualified_user_id,
+            login_submission,
+            ratelimit=appservice.is_rate_limited(),
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_other_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """Handle non-token/saml/jwt logins
 
         Args:
             login_submission:
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             HTTP response
@@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet):
             login_submission, ratelimit=True
         )
         result = await self._complete_login(
-            canonical_user_id, login_submission, callback
+            canonical_user_id,
+            login_submission,
+            callback,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet):
         self,
         user_id: str,
         login_submission: JsonDict,
-        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
+        callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
         ratelimit: bool = True,
         auth_provider_id: Optional[str] = None,
-    ) -> Dict[str, str]:
+        should_issue_refresh_token: bool = False,
+    ) -> LoginResponse:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
         all successful logins.
@@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet):
             ratelimit: Whether to ratelimit the login request.
             auth_provider_id: The SSO IdP the user used, if any (just used for the
                 prometheus metrics).
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet):
 
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
-        device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
+        (
+            device_id,
+            access_token,
+            valid_until_ms,
+            refresh_token,
+        ) = await self.registration_handler.register_device(
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-        result = {
-            "user_id": user_id,
-            "access_token": access_token,
-            "home_server": self.hs.hostname,
-            "device_id": device_id,
-        }
+        result = LoginResponse(
+            user_id=user_id,
+            access_token=access_token,
+            home_server=self.hs.hostname,
+            device_id=device_id,
+        )
+
+        if valid_until_ms is not None:
+            expires_in_ms = valid_until_ms - self.clock.time_msec()
+            result["expires_in_ms"] = expires_in_ms
+
+        if refresh_token is not None:
+            result["refresh_token"] = refresh_token
 
         if callback is not None:
             await callback(result)
 
         return result
 
-    async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_token_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         """
         Handle the final stage of SSO login.
 
         Args:
-             login_submission: The JSON request body.
+            login_submission: The JSON request body.
+            should_issue_refresh_token: True if this login should issue
+                a refresh token alongside the access token.
 
         Returns:
             The body of the JSON response.
@@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet):
             login_submission,
             self.auth_handler._sso_login_callback,
             auth_provider_id=res.auth_provider_id,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
-    async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+    async def _do_jwt_login(
+        self, login_submission: JsonDict, should_issue_refresh_token: bool = False
+    ) -> LoginResponse:
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
@@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet):
 
         user_id = UserID(user, self.hs.hostname).to_string()
         result = await self._complete_login(
-            user_id, login_submission, create_non_existent_users=True
+            user_id,
+            login_submission,
+            create_non_existent_users=True,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
         return result
 
@@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp(
     return e
 
 
+class RefreshTokenServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        self._auth_handler = hs.get_auth_handler()
+        self._clock = hs.get_clock()
+        self.access_token_lifetime = hs.config.access_token_lifetime
+
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+    ):
+        refresh_submission = parse_json_object_from_request(request)
+
+        assert_params_in_dict(refresh_submission, ["refresh_token"])
+        token = refresh_submission["refresh_token"]
+        if not isinstance(token, str):
+            raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM)
+
+        valid_until_ms = self._clock.time_msec() + self.access_token_lifetime
+        access_token, refresh_token = await self._auth_handler.refresh_token(
+            token, valid_until_ms
+        )
+        expires_in_ms = valid_until_ms - self._clock.time_msec()
+        return (
+            200,
+            {
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "expires_in_ms": expires_in_ms,
+            },
+        )
+
+
 class SsoRedirectServlet(RestServlet):
     PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
         re.compile(
@@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    if hs.config.access_token_lifetime is not None:
+        RefreshTokenServlet(hs).register(http_server)
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a30a5df1b1..4d31584acd 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
+    parse_boolean,
     parse_json_object_from_request,
     parse_string,
 )
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
+from synapse.types import JsonDict
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet):
         self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
         self._registration_enabled = self.hs.config.enable_registration
+        self._msc2918_enabled = hs.config.access_token_lifetime is not None
 
         self._registration_flows = _calculate_registration_flows(
             hs.config, self.auth_handler
@@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet):
                 "Do not understand membership kind: %s" % (kind.decode("utf8"),)
             )
 
+        if self._msc2918_enabled:
+            # Check if this registration should also issue a refresh token, as
+            # per MSC2918
+            should_issue_refresh_token = parse_boolean(
+                request, name="org.matrix.msc2918.refresh_token", default=False
+            )
+        else:
+            should_issue_refresh_token = False
+
         # Pull out the provided username and do basic sanity checks early since
         # the auth layer will store these in sessions.
         desired_username = None
@@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet):
                 raise SynapseError(400, "Desired Username is missing or not a string")
 
             result = await self._do_appservice_registration(
-                desired_username, access_token, body
+                desired_username,
+                access_token,
+                body,
+                should_issue_refresh_token=should_issue_refresh_token,
             )
 
             return 200, result
@@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet):
             registered = True
 
         return_dict = await self._create_registration_details(
-            registered_user_id, params
+            registered_user_id,
+            params,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
         if registered:
@@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet):
 
         return 200, return_dict
 
-    async def _do_appservice_registration(self, username, as_token, body):
+    async def _do_appservice_registration(
+        self, username, as_token, body, should_issue_refresh_token: bool = False
+    ):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
@@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet):
             user_id,
             body,
             is_appservice_ghost=True,
+            should_issue_refresh_token=should_issue_refresh_token,
         )
 
     async def _create_registration_details(
-        self, user_id, params, is_appservice_ghost=False
+        self,
+        user_id: str,
+        params: JsonDict,
+        is_appservice_ghost: bool = False,
+        should_issue_refresh_token: bool = False,
     ):
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
 
         Args:
-            (str) user_id: full canonical @user:id
-            (object) params: registration parameters, from which we pull
-                device_id, initial_device_name and inhibit_login
+            user_id: full canonical @user:id
+            params: registration parameters, from which we pull device_id,
+                initial_device_name and inhibit_login
+            is_appservice_ghost
+            should_issue_refresh_token: True if this registration should issue
+                a refresh token alongside the access token.
         Returns:
              dictionary for response from /register
         """
@@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet):
         if not params.get("inhibit_login", False):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
-            device_id, access_token = await self.registration_handler.register_device(
+            (
+                device_id,
+                access_token,
+                valid_until_ms,
+                refresh_token,
+            ) = await self.registration_handler.register_device(
                 user_id,
                 device_id,
                 initial_display_name,
                 is_guest=False,
                 is_appservice_ghost=is_appservice_ghost,
+                should_issue_refresh_token=should_issue_refresh_token,
             )
 
             result.update({"access_token": access_token, "device_id": device_id})
+
+            if valid_until_ms is not None:
+                expires_in_ms = valid_until_ms - self.clock.time_msec()
+                result["expires_in_ms"] = expires_in_ms
+
+            if refresh_token is not None:
+                result["refresh_token"] = refresh_token
+
         return result
 
     async def _do_guest_registration(self, params, address=None):
@@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet):
         # we have nowhere to store it.
         device_id = synapse.api.auth.GUEST_DEVICE_ID
         initial_display_name = params.get("initial_device_display_name")
-        device_id, access_token = await self.registration_handler.register_device(
+        (
+            device_id,
+            access_token,
+            valid_until_ms,
+            refresh_token,
+        ) = await self.registration_handler.register_device(
             user_id, device_id, initial_display_name, is_guest=True
         )
 
-        return (
-            200,
-            {
-                "user_id": user_id,
-                "device_id": device_id,
-                "access_token": access_token,
-                "home_server": self.hs.hostname,
-            },
-        )
+        result = {
+            "user_id": user_id,
+            "device_id": device_id,
+            "access_token": access_token,
+            "home_server": self.hs.hostname,
+        }
+
+        if valid_until_ms is not None:
+            expires_in_ms = valid_until_ms - self.clock.time_msec()
+            result["expires_in_ms"] = expires_in_ms
+
+        if refresh_token is not None:
+            result["refresh_token"] = refresh_token
+
+        return 200, result
 
 
 def _calculate_registration_flows(
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 042e1788b6..ecbbcf3851 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import itertools
 import logging
+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
 
 from synapse.api.constants import Membership, PresenceState
@@ -232,29 +233,51 @@ class SyncRestServlet(RestServlet):
         )
 
         logger.debug("building sync response dict")
-        return {
-            "account_data": {"events": sync_result.account_data},
-            "to_device": {"events": sync_result.to_device},
-            "device_lists": {
-                "changed": list(sync_result.device_lists.changed),
-                "left": list(sync_result.device_lists.left),
-            },
-            "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
-            "rooms": {
-                Membership.JOIN: joined,
-                Membership.INVITE: invited,
-                Membership.KNOCK: knocked,
-                Membership.LEAVE: archived,
-            },
-            "groups": {
-                Membership.JOIN: sync_result.groups.join,
-                Membership.INVITE: sync_result.groups.invite,
-                Membership.LEAVE: sync_result.groups.leave,
-            },
-            "device_one_time_keys_count": sync_result.device_one_time_keys_count,
-            "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
-            "next_batch": await sync_result.next_batch.to_string(self.store),
-        }
+
+        response: dict = defaultdict(dict)
+        response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+        if sync_result.account_data:
+            response["account_data"] = {"events": sync_result.account_data}
+        if sync_result.presence:
+            response["presence"] = SyncRestServlet.encode_presence(
+                sync_result.presence, time_now
+            )
+
+        if sync_result.to_device:
+            response["to_device"] = {"events": sync_result.to_device}
+
+        if sync_result.device_lists.changed:
+            response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+        if sync_result.device_lists.left:
+            response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+        if sync_result.device_one_time_keys_count:
+            response[
+                "device_one_time_keys_count"
+            ] = sync_result.device_one_time_keys_count
+        if sync_result.device_unused_fallback_key_types:
+            response[
+                "org.matrix.msc2732.device_unused_fallback_key_types"
+            ] = sync_result.device_unused_fallback_key_types
+
+        if joined:
+            response["rooms"][Membership.JOIN] = joined
+        if invited:
+            response["rooms"][Membership.INVITE] = invited
+        if knocked:
+            response["rooms"][Membership.KNOCK] = knocked
+        if archived:
+            response["rooms"][Membership.LEAVE] = archived
+
+        if sync_result.groups.join:
+            response["groups"][Membership.JOIN] = sync_result.groups.join
+        if sync_result.groups.invite:
+            response["groups"][Membership.INVITE] = sync_result.groups.invite
+        if sync_result.groups.leave:
+            response["groups"][Membership.LEAVE] = sync_result.groups.leave
+
+        return response
 
     @staticmethod
     def encode_presence(events, time_now):
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5c5cf8ff0..e31c5864ac 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -53,6 +53,9 @@ class TokenLookupResult:
         valid_until_ms: The timestamp the token expires, if any.
         token_owner: The "owner" of the token. This is either the same as the
             user, or a server admin who is logged in as the user.
+        token_used: True if this token was used at least once in a request.
+            This field can be out of date since `get_user_by_access_token` is
+            cached.
     """
 
     user_id = attr.ib(type=str)
@@ -62,6 +65,7 @@ class TokenLookupResult:
     device_id = attr.ib(type=Optional[str], default=None)
     valid_until_ms = attr.ib(type=Optional[int], default=None)
     token_owner = attr.ib(type=str)
+    token_used = attr.ib(type=bool, default=False)
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
@@ -69,6 +73,29 @@ class TokenLookupResult:
         return self.user_id
 
 
+@attr.s(frozen=True, slots=True)
+class RefreshTokenLookupResult:
+    """Result of looking up a refresh token."""
+
+    user_id = attr.ib(type=str)
+    """The user this token belongs to."""
+
+    device_id = attr.ib(type=str)
+    """The device associated with this refresh token."""
+
+    token_id = attr.ib(type=int)
+    """The ID of this refresh token."""
+
+    next_token_id = attr.ib(type=Optional[int])
+    """The ID of the refresh token which replaced this one."""
+
+    has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+    """True if the next refresh token was used for another refresh."""
+
+    has_next_access_token_been_used = attr.ib(type=bool)
+    """True if the next access token was already used at least once."""
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 access_tokens.id as token_id,
                 access_tokens.device_id,
                 access_tokens.valid_until_ms,
-                access_tokens.user_id as token_owner
+                access_tokens.user_id as token_owner,
+                access_tokens.used as token_used
             FROM users
             INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
             WHERE token = ?
@@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
+
         if rows:
-            return TokenLookupResult(**rows[0])
+            row = rows[0]
+
+            # This field is nullable, ensure it comes out as a boolean
+            if row["token_used"] is None:
+                row["token_used"] = False
+
+            return TokenLookupResult(**row)
 
         return None
 
@@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="update_access_token_last_validated",
         )
 
+    @cached()
+    async def mark_access_token_as_used(self, token_id: int) -> None:
+        """
+        Mark the access token as used, which invalidates the refresh token used
+        to obtain it.
+
+        Because get_user_by_access_token is cached, this function might be
+        called multiple times for the same token, effectively doing unnecessary
+        SQL updates. Because updating the `used` field only goes one way (from
+        False to True) it is safe to cache this function as well to avoid this
+        issue.
+
+        Args:
+            token_id: The ID of the access token to update.
+        Raises:
+            StoreError if there was a problem updating this.
+        """
+        await self.db_pool.simple_update_one(
+            "access_tokens",
+            {"id": token_id},
+            {"used": True},
+            desc="mark_access_token_as_used",
+        )
+
+    async def lookup_refresh_token(
+        self, token: str
+    ) -> Optional[RefreshTokenLookupResult]:
+        """Lookup a refresh token with hints about its validity."""
+
+        def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+            txn.execute(
+                """
+                SELECT
+                    rt.id token_id,
+                    rt.user_id,
+                    rt.device_id,
+                    rt.next_token_id,
+                    (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+                    at.used has_next_access_token_been_used
+                FROM refresh_tokens rt
+                LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
+                LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
+                WHERE rt.token = ?
+            """,
+                (token,),
+            )
+            row = txn.fetchone()
+
+            if row is None:
+                return None
+
+            return RefreshTokenLookupResult(
+                token_id=row[0],
+                user_id=row[1],
+                device_id=row[2],
+                next_token_id=row[3],
+                has_next_refresh_token_been_refreshed=row[4],
+                # This column is nullable, ensure it's a boolean
+                has_next_access_token_been_used=(row[5] or False),
+            )
+
+        return await self.db_pool.runInteraction(
+            "lookup_refresh_token", _lookup_refresh_token_txn
+        )
+
+    async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
+        """
+        Set the successor of a refresh token, removing the existing successor
+        if any.
+
+        Args:
+            token_id: ID of the refresh token to update.
+            next_token_id: ID of its successor.
+        """
+
+        def _replace_refresh_token_txn(txn) -> None:
+            # First check if there was an existing refresh token
+            old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                "refresh_tokens",
+                {"id": token_id},
+                "next_token_id",
+                allow_none=True,
+            )
+
+            self.db_pool.simple_update_one_txn(
+                txn,
+                "refresh_tokens",
+                {"id": token_id},
+                {"next_token_id": next_token_id},
+            )
+
+            # Delete the old "next" token if it exists. This should cascade and
+            # delete the associated access_token
+            if old_next_token_id is not None:
+                self.db_pool.simple_delete_one_txn(
+                    txn,
+                    "refresh_tokens",
+                    {"id": old_next_token_id},
+                )
+
+        await self.db_pool.runInteraction(
+            "replace_refresh_token", _replace_refresh_token_txn
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(
@@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
 
     async def add_access_token_to_user(
         self,
@@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         device_id: Optional[str],
         valid_until_ms: Optional[int],
         puppets_user_id: Optional[str] = None,
+        refresh_token_id: Optional[int] = None,
     ) -> int:
         """Adds an access token for the given user.
 
         Args:
             user_id: The user ID.
             token: The new access token to add.
-            device_id: ID of the device to associate with the access token
+            device_id: ID of the device to associate with the access token.
             valid_until_ms: when the token is valid until. None for no expiry.
+            puppets_user_id
+            refresh_token_id: ID of the refresh token generated alongside this
+                access token.
         Raises:
             StoreError if there was a problem adding this.
         Returns:
@@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 "valid_until_ms": valid_until_ms,
                 "puppets_user_id": puppets_user_id,
                 "last_validated": now,
+                "refresh_token_id": refresh_token_id,
+                "used": False,
             },
             desc="add_access_token_to_user",
         )
 
         return next_id
 
+    async def add_refresh_token_to_user(
+        self,
+        user_id: str,
+        token: str,
+        device_id: Optional[str],
+    ) -> int:
+        """Adds a refresh token for the given user.
+
+        Args:
+            user_id: The user ID.
+            token: The new access token to add.
+            device_id: ID of the device to associate with the refresh token.
+        Raises:
+            StoreError if there was a problem adding this.
+        Returns:
+            The token ID
+        """
+        next_id = self._refresh_tokens_id_gen.get_next()
+
+        await self.db_pool.simple_insert(
+            "refresh_tokens",
+            {
+                "id": next_id,
+                "user_id": user_id,
+                "device_id": device_id,
+                "token": token,
+                "next_token_id": None,
+            },
+            desc="add_refresh_token_to_user",
+        )
+
+        return next_id
+
     def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
@@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         device_id: Optional[str] = None,
     ) -> List[Tuple[str, int, Optional[str]]]:
         """
-        Invalidate access tokens belonging to a user
+        Invalidate access and refresh tokens belonging to a user
 
         Args:
             user_id: ID of user the tokens belong to
@@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]  # type: List[Union[str, int]]
+            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
+            # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
+            # clause and values before we handle that. This seems to be only used in the "set password" handler.
+            refresh_where_clause = where_clause
+            refresh_values = values.copy()
             if except_token_id:
+                # TODO: support that for refresh tokens
                 where_clause += " AND id != ?"
                 values.append(except_token_id)
 
@@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
 
+            txn.execute(
+                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
+                refresh_values,
+            )
+
             return tokens_and_devices
 
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
@@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
+    async def delete_refresh_token(self, refresh_token: str) -> None:
+        def f(txn):
+            self.db_pool.simple_delete_one_txn(
+                txn, table="refresh_tokens", keyvalues={"token": refresh_token}
+            )
+
+        await self.db_pool.runInteraction("delete_refresh_token", f)
+
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
diff --git a/synapse/storage/schema/main/delta/59/14refresh_tokens.sql b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
new file mode 100644
index 0000000000..9a6bce1e3e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql
@@ -0,0 +1,34 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Holds MSC2918 refresh tokens
+CREATE TABLE refresh_tokens (
+  id BIGINT PRIMARY KEY,
+  user_id TEXT NOT NULL,
+  device_id TEXT NOT NULL,
+  token TEXT NOT NULL,
+  -- When consumed, a new refresh token is generated, which is tracked by
+  -- this foreign key
+  next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE,
+  UNIQUE(token)
+);
+
+-- Add a reference to the refresh token generated alongside each access token
+ALTER TABLE "access_tokens"
+  ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE;
+
+-- Add a flag whether the token was already used or not
+ALTER TABLE "access_tokens"
+  ADD COLUMN used BOOLEAN;