summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py124
-rw-r--r--synapse/handlers/device.py8
-rw-r--r--synapse/handlers/events.py5
-rw-r--r--synapse/handlers/federation.py61
-rw-r--r--synapse/handlers/initial_sync.py30
-rw-r--r--synapse/handlers/message.py8
-rw-r--r--synapse/handlers/oidc.py58
-rw-r--r--synapse/handlers/pagination.py3
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/register.py78
-rw-r--r--synapse/handlers/room.py165
-rw-r--r--synapse/handlers/room_summary.py14
-rw-r--r--synapse/handlers/sso.py4
-rw-r--r--synapse/handlers/sync.py285
14 files changed, 642 insertions, 203 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4b66a9862f..61607cf2ba 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,6 +18,7 @@ import time
 import unicodedata
 import urllib.parse
 from binascii import crc32
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -38,6 +39,7 @@ import attr
 import bcrypt
 import pymacaroons
 import unpaddedbase64
+from pymacaroons.exceptions import MacaroonVerificationFailedException
 
 from twisted.web.server import Request
 
@@ -181,8 +183,11 @@ class LoginTokenAttributes:
 
     user_id = attr.ib(type=str)
 
-    # the SSO Identity Provider that the user authenticated with, to get this token
     auth_provider_id = attr.ib(type=str)
+    """The SSO Identity Provider that the user authenticated with, to get this token."""
+
+    auth_provider_session_id = attr.ib(type=Optional[str])
+    """The session ID advertised by the SSO Identity Provider."""
 
 
 class AuthHandler:
@@ -756,53 +761,109 @@ class AuthHandler:
     async def refresh_token(
         self,
         refresh_token: str,
-        valid_until_ms: Optional[int],
-    ) -> Tuple[str, str]:
+        access_token_valid_until_ms: Optional[int],
+        refresh_token_valid_until_ms: Optional[int],
+    ) -> Tuple[str, str, Optional[int]]:
         """
         Consumes a refresh token and generate both a new access token and a new refresh token from it.
 
         The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token.
 
+        The lifetime of both the access token and refresh token will be capped so that they
+        do not exceed the session's ultimate expiry time, if applicable.
+
         Args:
             refresh_token: The token to consume.
-            valid_until_ms: The expiration timestamp of the new access token.
-
+            access_token_valid_until_ms: The expiration timestamp of the new access token.
+                None if the access token does not expire.
+            refresh_token_valid_until_ms: The expiration timestamp of the new refresh token.
+                None if the refresh token does not expire.
         Returns:
-            A tuple containing the new access token and refresh token
+            A tuple containing:
+              - the new access token
+              - the new refresh token
+              - the actual expiry time of the access token, which may be earlier than
+                `access_token_valid_until_ms`.
         """
 
         # Verify the token signature first before looking up the token
         if not self._verify_refresh_token(refresh_token):
-            raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED, "invalid refresh token", Codes.UNKNOWN_TOKEN
+            )
 
         existing_token = await self.store.lookup_refresh_token(refresh_token)
         if existing_token is None:
-            raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN)
+            raise SynapseError(
+                HTTPStatus.UNAUTHORIZED,
+                "refresh token does not exist",
+                Codes.UNKNOWN_TOKEN,
+            )
 
         if (
             existing_token.has_next_access_token_been_used
             or existing_token.has_next_refresh_token_been_refreshed
         ):
             raise SynapseError(
-                403, "refresh token isn't valid anymore", Codes.FORBIDDEN
+                HTTPStatus.FORBIDDEN,
+                "refresh token isn't valid anymore",
+                Codes.FORBIDDEN,
             )
 
+        now_ms = self._clock.time_msec()
+
+        if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
+
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "The supplied refresh token has expired",
+                Codes.FORBIDDEN,
+            )
+
+        if existing_token.ultimate_session_expiry_ts is not None:
+            # This session has a bounded lifetime, even across refreshes.
+
+            if access_token_valid_until_ms is not None:
+                access_token_valid_until_ms = min(
+                    access_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                access_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+
+            if refresh_token_valid_until_ms is not None:
+                refresh_token_valid_until_ms = min(
+                    refresh_token_valid_until_ms,
+                    existing_token.ultimate_session_expiry_ts,
+                )
+            else:
+                refresh_token_valid_until_ms = existing_token.ultimate_session_expiry_ts
+            if existing_token.ultimate_session_expiry_ts < now_ms:
+                raise SynapseError(
+                    HTTPStatus.FORBIDDEN,
+                    "The session has expired and can no longer be refreshed",
+                    Codes.FORBIDDEN,
+                )
+
         (
             new_refresh_token,
             new_refresh_token_id,
         ) = await self.create_refresh_token_for_user_id(
-            user_id=existing_token.user_id, device_id=existing_token.device_id
+            user_id=existing_token.user_id,
+            device_id=existing_token.device_id,
+            expiry_ts=refresh_token_valid_until_ms,
+            ultimate_session_expiry_ts=existing_token.ultimate_session_expiry_ts,
         )
         access_token = await self.create_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
-            valid_until_ms=valid_until_ms,
+            valid_until_ms=access_token_valid_until_ms,
             refresh_token_id=new_refresh_token_id,
         )
         await self.store.replace_refresh_token(
             existing_token.token_id, new_refresh_token_id
         )
-        return access_token, new_refresh_token
+        return access_token, new_refresh_token, access_token_valid_until_ms
 
     def _verify_refresh_token(self, token: str) -> bool:
         """
@@ -836,6 +897,8 @@ class AuthHandler:
         self,
         user_id: str,
         device_id: str,
+        expiry_ts: Optional[int],
+        ultimate_session_expiry_ts: Optional[int],
     ) -> Tuple[str, int]:
         """
         Creates a new refresh token for the user with the given user ID.
@@ -843,6 +906,13 @@ class AuthHandler:
         Args:
             user_id: canonical user ID
             device_id: the device ID to associate with the token.
+            expiry_ts (milliseconds since the epoch): Time after which the
+                refresh token cannot be used.
+                If None, the refresh token never expires until it has been used.
+            ultimate_session_expiry_ts (milliseconds since the epoch):
+                Time at which the session will end and can not be extended any
+                further.
+                If None, the session can be refreshed indefinitely.
 
         Returns:
             The newly created refresh token and its ID in the database
@@ -852,6 +922,8 @@ class AuthHandler:
             user_id=user_id,
             token=refresh_token,
             device_id=device_id,
+            expiry_ts=expiry_ts,
+            ultimate_session_expiry_ts=ultimate_session_expiry_ts,
         )
         return refresh_token, refresh_token_id
 
@@ -1582,6 +1654,7 @@ class AuthHandler:
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1597,6 +1670,7 @@ class AuthHandler:
                 during successful login. Must be JSON serializable.
             new_user: True if we should use wording appropriate to a user who has just
                 registered.
+            auth_provider_session_id: The session ID from the SSO IdP received during login.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1617,6 +1691,7 @@ class AuthHandler:
             extra_attributes,
             new_user=new_user,
             user_profile_data=profile,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     def _complete_sso_login(
@@ -1628,6 +1703,7 @@ class AuthHandler:
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         The synchronous portion of complete_sso_login.
@@ -1649,7 +1725,9 @@ class AuthHandler:
 
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id, auth_provider_id=auth_provider_id
+            registered_user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         # Append the login token to the original redirect URL (i.e. with its query
@@ -1754,6 +1832,7 @@ class MacaroonGenerator:
         self,
         user_id: str,
         auth_provider_id: str,
+        auth_provider_session_id: Optional[str] = None,
         duration_in_ms: int = (2 * 60 * 1000),
     ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1762,6 +1841,10 @@ class MacaroonGenerator:
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
+        if auth_provider_session_id is not None:
+            macaroon.add_first_party_caveat(
+                "auth_provider_session_id = %s" % (auth_provider_session_id,)
+            )
         return macaroon.serialize()
 
     def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
@@ -1783,15 +1866,28 @@ class MacaroonGenerator:
         user_id = get_value_from_macaroon(macaroon, "user_id")
         auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
 
+        auth_provider_session_id: Optional[str] = None
+        try:
+            auth_provider_session_id = get_value_from_macaroon(
+                macaroon, "auth_provider_session_id"
+            )
+        except MacaroonVerificationFailedException:
+            pass
+
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = login")
         v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+        v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
         satisfy_expiry(v, self.hs.get_clock().time_msec)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-        return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+        return LoginTokenAttributes(
+            user_id=user_id,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
 
     def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 68b446eb66..82ee11e921 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -301,6 +301,8 @@ class DeviceHandler(DeviceWorkerHandler):
         user_id: str,
         device_id: Optional[str],
         initial_device_display_name: Optional[str] = None,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> str:
         """
         If the given device has not been registered, register it with the
@@ -312,6 +314,8 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id:  @user:id
             device_id: device id supplied by client
             initial_device_display_name: device display name from client
+            auth_provider_id: The SSO IdP the user used, if any.
+            auth_provider_session_id: The session ID (sid) got from the SSO IdP.
         Returns:
             device id (generated if none was supplied)
         """
@@ -323,6 +327,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [device_id])
@@ -337,6 +343,8 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
+                auth_provider_id=auth_provider_id,
+                auth_provider_session_id=auth_provider_session_id,
             )
             if new_device:
                 await self.notify_device_update(user_id, [new_device_id])
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b4ff935546..32b0254c5f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -122,9 +122,8 @@ class EventStreamHandler:
                 events,
                 time_now,
                 as_client_event=as_client_event,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
-                bundle_relations=False,
+                # Don't bundle aggregations as this is a deprecated API.
+                bundle_aggregations=False,
             )
 
             chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3112cc88b1..1ea837d082 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,6 +68,37 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
+    """Get joined domains from state
+
+    Args:
+        state: State map from type/state key to event.
+
+    Returns:
+        Returns a list of servers with the lowest depth of their joins.
+            Sorted by lowest depth first.
+    """
+    joined_users = [
+        (state_key, int(event.depth))
+        for (e_type, state_key), event in state.items()
+        if e_type == EventTypes.Member and event.membership == Membership.JOIN
+    ]
+
+    joined_domains: Dict[str, int] = {}
+    for u, d in joined_users:
+        try:
+            dom = get_domain_from_id(u)
+            old_d = joined_domains.get(dom)
+            if old_d:
+                joined_domains[dom] = min(d, old_d)
+            else:
+                joined_domains[dom] = d
+        except Exception:
+            pass
+
+    return sorted(joined_domains.items(), key=lambda d: d[1])
+
+
 class FederationHandler:
     """Handles general incoming federation requests
 
@@ -268,36 +299,6 @@ class FederationHandler:
 
         curr_state = await self.state_handler.get_current_state(room_id)
 
-        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-            """Get joined domains from state
-
-            Args:
-                state: State map from type/state key to event.
-
-            Returns:
-                Returns a list of servers with the lowest depth of their joins.
-                 Sorted by lowest depth first.
-            """
-            joined_users = [
-                (state_key, int(event.depth))
-                for (e_type, state_key), event in state.items()
-                if e_type == EventTypes.Member and event.membership == Membership.JOIN
-            ]
-
-            joined_domains: Dict[str, int] = {}
-            for u, d in joined_users:
-                try:
-                    dom = get_domain_from_id(u)
-                    old_d = joined_domains.get(dom)
-                    if old_d:
-                        joined_domains[dom] = min(d, old_d)
-                    else:
-                        joined_domains[dom] = d
-                except Exception:
-                    pass
-
-            return sorted(joined_domains.items(), key=lambda d: d[1])
-
         curr_domains = get_domains_from_state(curr_state)
 
         likely_domains = [
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d4e4556155..9cd21e7f2b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -165,7 +165,11 @@ class InitialSyncHandler:
 
                 invite_event = await self.store.get_event(event.event_id)
                 d["invite"] = await self._event_serializer.serialize_event(
-                    invite_event, time_now, as_client_event
+                    invite_event,
+                    time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
+                    as_client_event=as_client_event,
                 )
 
             rooms_ret.append(d)
@@ -216,7 +220,11 @@ class InitialSyncHandler:
                 d["messages"] = {
                     "chunk": (
                         await self._event_serializer.serialize_events(
-                            messages, time_now=time_now, as_client_event=as_client_event
+                            messages,
+                            time_now=time_now,
+                            # Don't bundle aggregations as this is a deprecated API.
+                            bundle_aggregations=False,
+                            as_client_event=as_client_event,
                         )
                     ),
                     "start": await start_token.to_string(self.store),
@@ -226,6 +234,8 @@ class InitialSyncHandler:
                 d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
+                    # Don't bundle aggregations as this is a deprecated API.
+                    bundle_aggregations=False,
                     as_client_event=as_client_event,
                 )
 
@@ -366,14 +376,18 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
             },
             "state": (
+                # Don't bundle aggregations as this is a deprecated API.
                 await self._event_serializer.serialize_events(
-                    room_state.values(), time_now
+                    room_state.values(), time_now, bundle_aggregations=False
                 )
             ),
             "presence": [],
@@ -392,8 +406,9 @@ class InitialSyncHandler:
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
+        # Don't bundle aggregations as this is a deprecated API.
         state = await self._event_serializer.serialize_events(
-            current_state.values(), time_now
+            current_state.values(), time_now, bundle_aggregations=False
         )
 
         now_token = self.hs.get_event_sources().get_current_token()
@@ -467,7 +482,10 @@ class InitialSyncHandler:
             "room_id": room_id,
             "messages": {
                 "chunk": (
-                    await self._event_serializer.serialize_events(messages, time_now)
+                    # Don't bundle aggregations as this is a deprecated API.
+                    await self._event_serializer.serialize_events(
+                        messages, time_now, bundle_aggregations=False
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 95b4fad3c6..87f671708c 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -247,13 +247,7 @@ class MessageHandler:
                 room_state = room_state_events[membership_event_id]
 
         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(
-            room_state.values(),
-            now,
-            # We don't bother bundling aggregations in when asked for state
-            # events, as clients won't use them.
-            bundle_relations=False,
-        )
+        events = await self._event_serializer.serialize_events(room_state.values(), now)
         return events
 
     async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 3665d91513..deb3539751 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -23,7 +23,7 @@ from authlib.common.security import generate_token
 from authlib.jose import JsonWebToken, jwt
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
-from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.core import CodeIDToken, UserInfo
 from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
 from jinja2 import Environment, Template
 from pymacaroons.exceptions import (
@@ -117,7 +117,8 @@ class OidcHandler:
         for idp_id, p in self._providers.items():
             try:
                 await p.load_metadata()
-                await p.load_jwks()
+                if not p._uses_userinfo:
+                    await p.load_jwks()
             except Exception as e:
                 raise Exception(
                     "Error while initialising OIDC provider %r" % (idp_id,)
@@ -498,10 +499,6 @@ class OidcProvider:
         return await self._jwks.get()
 
     async def _load_jwks(self) -> JWKS:
-        if self._uses_userinfo:
-            # We're not using jwt signing, return an empty jwk set
-            return {"keys": []}
-
         metadata = await self.load_metadata()
 
         # Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
         """Return an instance of UserInfo from token's ``id_token``.
 
         Args:
@@ -673,7 +670,7 @@ class OidcProvider:
                 request. This value should match the one inside the token.
 
         Returns:
-            An object representing the user.
+            The decoded claims in the ID token.
         """
         metadata = await self.load_metadata()
         claims_params = {
@@ -684,9 +681,6 @@ class OidcProvider:
             # If we got an `access_token`, there should be an `at_hash` claim
             # in the `id_token` that we can check against.
             claims_params["access_token"] = token["access_token"]
-            claims_cls = CodeIDToken
-        else:
-            claims_cls = ImplicitIDToken
 
         alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
@@ -703,7 +697,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -713,7 +707,7 @@ class OidcProvider:
             claims = jwt.decode(
                 id_token,
                 key=jwk_set,
-                claims_cls=claims_cls,
+                claims_cls=CodeIDToken,
                 claims_options=claim_options,
                 claims_params=claims_params,
             )
@@ -721,7 +715,8 @@ class OidcProvider:
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
         claims.validate(leeway=120)  # allows 2 min of clock skew
-        return UserInfo(claims)
+
+        return claims
 
     async def handle_redirect_request(
         self,
@@ -837,8 +832,22 @@ class OidcProvider:
 
         logger.debug("Successfully obtained OAuth2 token data: %r", token)
 
-        # Now that we have a token, get the userinfo, either by decoding the
-        # `id_token` or by fetching the `userinfo_endpoint`.
+        # If there is an id_token, it should be validated, regardless of the
+        # userinfo endpoint is used or not.
+        if token.get("id_token") is not None:
+            try:
+                id_token = await self._parse_id_token(token, nonce=session_data.nonce)
+                sid = id_token.get("sid")
+            except Exception as e:
+                logger.exception("Invalid id_token")
+                self._sso_handler.render_error(request, "invalid_token", str(e))
+                return
+        else:
+            id_token = None
+            sid = None
+
+        # Now that we have a token, get the userinfo either from the `id_token`
+        # claims or by fetching the `userinfo_endpoint`.
         if self._uses_userinfo:
             try:
                 userinfo = await self._fetch_userinfo(token)
@@ -846,13 +855,14 @@ class OidcProvider:
                 logger.exception("Could not fetch userinfo")
                 self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
+        elif id_token is not None:
+            userinfo = UserInfo(id_token)
         else:
-            try:
-                userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
-            except Exception as e:
-                logger.exception("Invalid id_token")
-                self._sso_handler.render_error(request, "invalid_token", str(e))
-                return
+            logger.error("Missing id_token in token response")
+            self._sso_handler.render_error(
+                request, "invalid_token", "Missing id_token in token response"
+            )
+            return
 
         # first check if we're doing a UIA
         if session_data.ui_auth_session_id:
@@ -884,7 +894,7 @@ class OidcProvider:
         # Call the mapper to register/login the user
         try:
             await self._complete_oidc_login(
-                userinfo, token, request, session_data.client_redirect_url
+                userinfo, token, request, session_data.client_redirect_url, sid
             )
         except MappingException as e:
             logger.exception("Could not map user")
@@ -896,6 +906,7 @@ class OidcProvider:
         token: Token,
         request: SynapseRequest,
         client_redirect_url: str,
+        sid: Optional[str],
     ) -> None:
         """Given a UserInfo response, complete the login flow
 
@@ -1008,6 +1019,7 @@ class OidcProvider:
             oidc_response_to_user_attributes,
             grandfather_existing_users,
             extra_attributes,
+            auth_provider_session_id=sid,
         )
 
     def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index cd64142735..4f42438053 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -406,9 +406,6 @@ class PaginationHandler:
             force: set true to skip checking for joined users.
         """
         with await self.pagination_lock.write(room_id):
-            # check we know about the room
-            await self.store.get_room_version_id(room_id)
-
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 3df872c578..454d06c973 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -421,7 +421,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             self._on_shutdown,
         )
 
-    def _on_shutdown(self) -> None:
+    async def _on_shutdown(self) -> None:
         if self._presence_enabled:
             self.hs.get_tcp_replication().send_command(
                 ClearUserSyncsCommand(self.instance_id)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 448a36108e..f08a516a75 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -1,4 +1,5 @@
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -116,9 +117,13 @@ class RegistrationHandler:
             self.pusher_pool = hs.get_pusherpool()
 
         self.session_lifetime = hs.config.registration.session_lifetime
+        self.nonrefreshable_access_token_lifetime = (
+            hs.config.registration.nonrefreshable_access_token_lifetime
+        )
         self.refreshable_access_token_lifetime = (
             hs.config.registration.refreshable_access_token_lifetime
         )
+        self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime
 
         init_counters_for_auth_provider("")
 
@@ -741,6 +746,7 @@ class RegistrationHandler:
         is_appservice_ghost: bool = False,
         auth_provider_id: Optional[str] = None,
         should_issue_refresh_token: bool = False,
+        auth_provider_session_id: Optional[str] = None,
     ) -> Tuple[str, str, Optional[int], Optional[str]]:
         """Register a device for a user and generate an access token.
 
@@ -751,9 +757,9 @@ class RegistrationHandler:
             device_id: The device ID to check, or None to generate a new one.
             initial_display_name: An optional display name for the device.
             is_guest: Whether this is a guest account
-            auth_provider_id: The SSO IdP the user used, if any (just used for the
-                prometheus metrics).
+            auth_provider_id: The SSO IdP the user used, if any.
             should_issue_refresh_token: Whether it should also issue a refresh token
+            auth_provider_session_id: The session ID received during login from the SSO IdP.
         Returns:
             Tuple of device ID, access token, access token expiration time and refresh token
         """
@@ -764,6 +770,8 @@ class RegistrationHandler:
             is_guest=is_guest,
             is_appservice_ghost=is_appservice_ghost,
             should_issue_refresh_token=should_issue_refresh_token,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
         login_counter.labels(
@@ -786,6 +794,8 @@ class RegistrationHandler:
         is_guest: bool = False,
         is_appservice_ghost: bool = False,
         should_issue_refresh_token: bool = False,
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> LoginDict:
         """Helper for register_device
 
@@ -793,40 +803,86 @@ class RegistrationHandler:
         class and RegisterDeviceReplicationServlet.
         """
         assert not self.hs.config.worker.worker_app
-        valid_until_ms = None
+        now_ms = self.clock.time_msec()
+        access_token_expiry = None
         if self.session_lifetime is not None:
             if is_guest:
                 raise Exception(
                     "session_lifetime is not currently implemented for guest access"
                 )
-            valid_until_ms = self.clock.time_msec() + self.session_lifetime
+            access_token_expiry = now_ms + self.session_lifetime
+
+        if self.nonrefreshable_access_token_lifetime is not None:
+            if access_token_expiry is not None:
+                # Don't allow the non-refreshable access token to outlive the
+                # session.
+                access_token_expiry = min(
+                    now_ms + self.nonrefreshable_access_token_lifetime,
+                    access_token_expiry,
+                )
+            else:
+                access_token_expiry = now_ms + self.nonrefreshable_access_token_lifetime
 
         refresh_token = None
         refresh_token_id = None
 
         registered_device_id = await self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
+            user_id,
+            device_id,
+            initial_display_name,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
         )
         if is_guest:
-            assert valid_until_ms is None
+            assert access_token_expiry is None
             access_token = self.macaroon_gen.generate_guest_access_token(user_id)
         else:
             if should_issue_refresh_token:
+                # A refreshable access token lifetime must be configured
+                # since we're told to issue a refresh token (the caller checks
+                # that this value is set before setting this flag).
+                assert self.refreshable_access_token_lifetime is not None
+
+                # Set the expiry time of the refreshable access token
+                access_token_expiry = now_ms + self.refreshable_access_token_lifetime
+
+                # Set the refresh token expiry time (if configured)
+                refresh_token_expiry = None
+                if self.refresh_token_lifetime is not None:
+                    refresh_token_expiry = now_ms + self.refresh_token_lifetime
+
+                # Set an ultimate session expiry time (if configured)
+                ultimate_session_expiry_ts = None
+                if self.session_lifetime is not None:
+                    ultimate_session_expiry_ts = now_ms + self.session_lifetime
+
+                    # Also ensure that the issued tokens don't outlive the
+                    # session.
+                    # (It would be weird to configure a homeserver with a shorter
+                    # session lifetime than token lifetime, but may as well handle
+                    # it.)
+                    access_token_expiry = min(
+                        access_token_expiry, ultimate_session_expiry_ts
+                    )
+                    if refresh_token_expiry is not None:
+                        refresh_token_expiry = min(
+                            refresh_token_expiry, ultimate_session_expiry_ts
+                        )
+
                 (
                     refresh_token,
                     refresh_token_id,
                 ) = await self._auth_handler.create_refresh_token_for_user_id(
                     user_id,
                     device_id=registered_device_id,
-                )
-                valid_until_ms = (
-                    self.clock.time_msec() + self.refreshable_access_token_lifetime
+                    expiry_ts=refresh_token_expiry,
+                    ultimate_session_expiry_ts=ultimate_session_expiry_ts,
                 )
 
             access_token = await self._auth_handler.create_access_token_for_user_id(
                 user_id,
                 device_id=registered_device_id,
-                valid_until_ms=valid_until_ms,
+                valid_until_ms=access_token_expiry,
                 is_appservice_ghost=is_appservice_ghost,
                 refresh_token_id=refresh_token_id,
             )
@@ -834,7 +890,7 @@ class RegistrationHandler:
         return {
             "device_id": registered_device_id,
             "access_token": access_token,
-            "valid_until_ms": valid_until_ms,
+            "valid_until_ms": access_token_expiry,
             "refresh_token": refresh_token,
         }
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 88053f9869..ead2198e14 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -46,6 +46,7 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
+    HttpResponseException,
     LimitExceededError,
     NotFoundError,
     StoreError,
@@ -56,6 +57,8 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
+from synapse.federation.federation_client import InvalidResponseError
+from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
@@ -1220,6 +1223,147 @@ class RoomContextHandler:
         return results
 
 
+class TimestampLookupHandler:
+    def __init__(self, hs: "HomeServer"):
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.state_handler = hs.get_state_handler()
+        self.federation_client = hs.get_federation_client()
+
+    async def get_event_for_timestamp(
+        self,
+        requester: Requester,
+        room_id: str,
+        timestamp: int,
+        direction: str,
+    ) -> Tuple[str, int]:
+        """Find the closest event to the given timestamp in the given direction.
+        If we can't find an event locally or the event we have locally is next to a gap,
+        it will ask other federated homeservers for an event.
+
+        Args:
+            requester: The user making the request according to the access token
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A tuple containing the `event_id` closest to the given timestamp in
+            the given direction and the `origin_server_ts`.
+
+        Raises:
+            SynapseError if unable to find any event locally in the given direction
+        """
+
+        local_event_id = await self.store.get_event_id_for_timestamp(
+            room_id, timestamp, direction
+        )
+        logger.debug(
+            "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s",
+            local_event_id,
+            timestamp,
+        )
+
+        # Check for gaps in the history where events could be hiding in between
+        # the timestamp given and the event we were able to find locally
+        is_event_next_to_backward_gap = False
+        is_event_next_to_forward_gap = False
+        if local_event_id:
+            local_event = await self.store.get_event(
+                local_event_id, allow_none=False, allow_rejected=False
+            )
+
+            if direction == "f":
+                # We only need to check for a backward gap if we're looking forwards
+                # to ensure there is nothing in between.
+                is_event_next_to_backward_gap = (
+                    await self.store.is_event_next_to_backward_gap(local_event)
+                )
+            elif direction == "b":
+                # We only need to check for a forward gap if we're looking backwards
+                # to ensure there is nothing in between
+                is_event_next_to_forward_gap = (
+                    await self.store.is_event_next_to_forward_gap(local_event)
+                )
+
+        # If we found a gap, we should probably ask another homeserver first
+        # about more history in between
+        if (
+            not local_event_id
+            or is_event_next_to_backward_gap
+            or is_event_next_to_forward_gap
+        ):
+            logger.debug(
+                "get_event_for_timestamp: locally, we found event_id=%s closest to timestamp=%s which is next to a gap in event history so we're asking other homeservers first",
+                local_event_id,
+                timestamp,
+            )
+
+            # Find other homeservers from the given state in the room
+            curr_state = await self.state_handler.get_current_state(room_id)
+            curr_domains = get_domains_from_state(curr_state)
+            likely_domains = [
+                domain for domain, depth in curr_domains if domain != self.server_name
+            ]
+
+            # Loop through each homeserver candidate until we get a succesful response
+            for domain in likely_domains:
+                try:
+                    remote_response = await self.federation_client.timestamp_to_event(
+                        domain, room_id, timestamp, direction
+                    )
+                    logger.debug(
+                        "get_event_for_timestamp: response from domain(%s)=%s",
+                        domain,
+                        remote_response,
+                    )
+
+                    # TODO: Do we want to persist this as an extremity?
+                    # TODO: I think ideally, we would try to backfill from
+                    # this event and run this whole
+                    # `get_event_for_timestamp` function again to make sure
+                    # they didn't give us an event from their gappy history.
+                    remote_event_id = remote_response.event_id
+                    origin_server_ts = remote_response.origin_server_ts
+
+                    # Only return the remote event if it's closer than the local event
+                    if not local_event or (
+                        abs(origin_server_ts - timestamp)
+                        < abs(local_event.origin_server_ts - timestamp)
+                    ):
+                        return remote_event_id, origin_server_ts
+                except (HttpResponseException, InvalidResponseError) as ex:
+                    # Let's not put a high priority on some other homeserver
+                    # failing to respond or giving a random response
+                    logger.debug(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+                except Exception as ex:
+                    # But we do want to see some exceptions in our code
+                    logger.warning(
+                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        domain,
+                        type(ex).__name__,
+                        ex,
+                        ex.args,
+                    )
+
+        if not local_event_id:
+            raise SynapseError(
+                404,
+                "Unable to find event from %s in direction %s" % (timestamp, direction),
+                errcode=Codes.NOT_FOUND,
+            )
+
+        return local_event_id, local_event.origin_server_ts
+
+
 class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -1391,20 +1535,13 @@ class RoomShutdownHandler:
             await self.store.block_room(room_id, requester_user_id)
 
         if not await self.store.get_room(room_id):
-            if block:
-                # We allow you to block an unknown room.
-                return {
-                    "kicked_users": [],
-                    "failed_to_kick_users": [],
-                    "local_aliases": [],
-                    "new_room_id": None,
-                }
-            else:
-                # But if you don't want to preventatively block another room,
-                # this function can't do anything useful.
-                raise NotFoundError(
-                    "Cannot shut down room: unknown room id %s" % (room_id,)
-                )
+            # if we don't know about the room, there is nothing left to do.
+            return {
+                "kicked_users": [],
+                "failed_to_kick_users": [],
+                "local_aliases": [],
+                "new_room_id": None,
+            }
 
         if new_room_user_id is not None:
             if not self.hs.is_mine_id(new_room_user_id):
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 8181cc0b52..b2cfe537df 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,8 +36,9 @@ from synapse.api.errors import (
     SynapseError,
     UnsupportedRoomVersionError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
-from synapse.types import JsonDict
+from synapse.types import JsonDict, Requester
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -93,6 +94,9 @@ class RoomSummaryHandler:
         self._event_serializer = hs.get_event_client_serializer()
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
+        self._ratelimiter = Ratelimiter(
+            store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10
+        )
 
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
@@ -249,7 +253,7 @@ class RoomSummaryHandler:
 
     async def get_room_hierarchy(
         self,
-        requester: str,
+        requester: Requester,
         requested_room_id: str,
         suggested_only: bool = False,
         max_depth: Optional[int] = None,
@@ -276,6 +280,8 @@ class RoomSummaryHandler:
         Returns:
             The JSON hierarchy dictionary.
         """
+        await self._ratelimiter.ratelimit(requester)
+
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         #
@@ -283,7 +289,7 @@ class RoomSummaryHandler:
         # to process multiple requests for the same page will result in errors.
         return await self._pagination_response_cache.wrap(
             (
-                requester,
+                requester.user.to_string(),
                 requested_room_id,
                 suggested_only,
                 max_depth,
@@ -291,7 +297,7 @@ class RoomSummaryHandler:
                 from_token,
             ),
             self._get_room_hierarchy,
-            requester,
+            requester.user.to_string(),
             requested_room_id,
             suggested_only,
             max_depth,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 49fde01cf0..65c27bc64a 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -365,6 +365,7 @@ class SsoHandler:
         sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
         grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
         extra_login_attributes: Optional[JsonDict] = None,
+        auth_provider_session_id: Optional[str] = None,
     ) -> None:
         """
         Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -415,6 +416,8 @@ class SsoHandler:
             extra_login_attributes: An optional dictionary of extra
                 attributes to be provided to the client in the login response.
 
+            auth_provider_session_id: An optional session ID from the IdP.
+
         Raises:
             MappingException if there was a problem mapping the response to a user.
             RedirectException: if the mapping provider needs to redirect the user
@@ -490,6 +493,7 @@ class SsoHandler:
             client_redirect_url,
             extra_login_attributes,
             new_user=new_user,
+            auth_provider_session_id=auth_provider_session_id,
         )
 
     async def _call_attribute_mapper(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 891435c14d..f3039c3c3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -334,6 +334,19 @@ class SyncHandler:
         full_state: bool,
         cache_context: ResponseCacheContext[SyncRequestKey],
     ) -> SyncResult:
+        """The start of the machinery that produces a /sync response.
+
+        See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
+
+        This method does high-level bookkeeping:
+        - tracking the kind of sync in the logging context
+        - deleting any to_device messages whose delivery has been acknowledged.
+        - deciding if we should dispatch an instant or delayed response
+        - marking the sync as being lazily loaded, if appropriate
+
+        Computing the body of the response begins in the next method,
+        `current_sync_for_user`.
+        """
         if since_token is None:
             sync_type = "initial_sync"
         elif full_state:
@@ -363,7 +376,7 @@ class SyncHandler:
                 sync_config, since_token, full_state=full_state
             )
         else:
-
+            # Otherwise, we wait for something to happen and report it to the user.
             async def current_sync_callback(
                 before_token: StreamToken, after_token: StreamToken
             ) -> SyncResult:
@@ -402,7 +415,12 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Get the sync for client needed to match what the server has now."""
+        """Generates the response body of a sync result, represented as a SyncResult.
+
+        This is a wrapper around `generate_sync_result` which starts an open tracing
+        span to track the sync. See `generate_sync_result` for the next part of your
+        indoctrination.
+        """
         with start_active_span("current_sync_for_user"):
             log_kv({"since_token": since_token})
             sync_result = await self.generate_sync_result(
@@ -560,7 +578,7 @@ class SyncHandler:
                 # that have happened since `since_key` up to `end_key`, so we
                 # can just use `get_room_events_stream_for_room`.
                 # Otherwise, we want to return the last N events in the room
-                # in toplogical ordering.
+                # in topological ordering.
                 if since_key:
                     events, end_key = await self.store.get_room_events_stream_for_room(
                         room_id,
@@ -1042,7 +1060,18 @@ class SyncHandler:
         since_token: Optional[StreamToken] = None,
         full_state: bool = False,
     ) -> SyncResult:
-        """Generates a sync result."""
+        """Generates the response body of a sync result.
+
+        This is represented by a `SyncResult` struct, which is built from small pieces
+        using a `SyncResultBuilder`. See also
+            https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+        the `sync_result_builder` is passed as a mutable ("inout") parameter to various
+        helper functions. These retrieve and process the data which forms the sync body,
+        often writing to the `sync_result_builder` to store their output.
+
+        At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
+        instance to signify that the sync calculation is complete.
+        """
         # NB: The now_token gets changed by some of the generate_sync_* methods,
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
@@ -1344,14 +1373,22 @@ class SyncHandler:
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response. Populates
-        `sync_result_builder` with the result.
+        """Generates the account data portion of the sync response.
+
+        Account data (called "Client Config" in the spec) can be set either globally
+        or for a specific room. Account data consists of a list of events which
+        accumulate state, much like a room.
+
+        This function retrieves global and per-room account data. The former is written
+        to the given `sync_result_builder`. The latter is returned directly, to be
+        later written to the `sync_result_builder` on a room-by-room basis.
 
         Args:
             sync_result_builder
 
         Returns:
-            A dictionary containing the per room account data.
+            A dictionary whose keys (room ids) map to the per room account data for that
+            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()
@@ -1359,7 +1396,7 @@ class SyncHandler:
 
         if since_token and not sync_result_builder.full_state:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_updated_account_data_for_user(
                 user_id, since_token.account_data_key
@@ -1370,23 +1407,23 @@ class SyncHandler:
             )
 
             if push_rules_changed:
-                account_data["m.push_rules"] = await self.push_rules_for_user(
+                global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
             (
-                account_data,
+                global_account_data,
                 account_data_by_room,
             ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
 
-            account_data["m.push_rules"] = await self.push_rules_for_user(
+            global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
         account_data_for_user = await sync_config.filter_collection.filter_account_data(
             [
                 {"type": account_data_type, "content": content}
-                for account_data_type, content in account_data.items()
+                for account_data_type, content in global_account_data.items()
             ]
         )
 
@@ -1460,18 +1497,31 @@ class SyncHandler:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
+        In the response that reaches the client, rooms are divided into four categories:
+        `invite`, `join`, `knock`, `leave`. These aren't the same as the four sets of
+        room ids returned by this function.
+
         Args:
             sync_result_builder
             account_data_by_room: Dictionary of per room account data
 
         Returns:
-            Returns a 4-tuple of
-            `(newly_joined_rooms, newly_joined_or_invited_users,
-            newly_left_rooms, newly_left_users)`
+            Returns a 4-tuple describing rooms the user has joined or left, and users who've
+            joined or left rooms any rooms the user is in. This gets used later in
+            `_generate_sync_entry_for_device_list`.
+
+            Its entries are:
+            - newly_joined_rooms
+            - newly_joined_or_invited_or_knocked_users
+            - newly_left_rooms
+            - newly_left_users
         """
+        since_token = sync_result_builder.since_token
+
+        # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
         user_id = sync_result_builder.sync_config.user.to_string()
         block_all_room_ephemeral = (
-            sync_result_builder.since_token is None
+            since_token is None
             and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
         )
 
@@ -1485,9 +1535,8 @@ class SyncHandler:
             )
             sync_result_builder.now_token = now_token
 
-        # We check up front if anything has changed, if it hasn't then there is
+        # 2. We check up front if anything has changed, if it hasn't then there is
         # no point in going further.
-        since_token = sync_result_builder.since_token
         if not sync_result_builder.full_state:
             if since_token and not ephemeral_by_room and not account_data_by_room:
                 have_changed = await self._have_rooms_changed(sync_result_builder)
@@ -1500,20 +1549,8 @@ class SyncHandler:
                         logger.debug("no-oping sync")
                         return set(), set(), set(), set()
 
-        ignored_account_data = (
-            await self.store.get_global_account_data_by_type_for_user(
-                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
-            )
-        )
-
-        # If there is ignored users account data and it matches the proper type,
-        # then use it.
-        ignored_users: FrozenSet[str] = frozenset()
-        if ignored_account_data:
-            ignored_users_data = ignored_account_data.get("ignored_users", {})
-            if isinstance(ignored_users_data, dict):
-                ignored_users = frozenset(ignored_users_data.keys())
-
+        # 3. Work out which rooms need reporting in the sync response.
+        ignored_users = await self._get_ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
                 sync_result_builder, ignored_users
@@ -1523,7 +1560,6 @@ class SyncHandler:
             )
         else:
             room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
-
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1534,6 +1570,8 @@ class SyncHandler:
         newly_joined_rooms = room_changes.newly_joined_rooms
         newly_left_rooms = room_changes.newly_left_rooms
 
+        # 4. We need to apply further processing to `room_entries` (rooms considered
+        # joined or archived).
         async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
             logger.debug("Generating room entry for %s", room_entry.room_id)
             await self._generate_room_entry(
@@ -1552,31 +1590,13 @@ class SyncHandler:
         sync_result_builder.invited.extend(invited)
         sync_result_builder.knocked.extend(knocked)
 
-        # Now we want to get any newly joined, invited or knocking users
-        newly_joined_or_invited_or_knocked_users = set()
-        newly_left_users = set()
-        if since_token:
-            for joined_sync in sync_result_builder.joined:
-                it = itertools.chain(
-                    joined_sync.timeline.events, joined_sync.state.values()
-                )
-                for event in it:
-                    if event.type == EventTypes.Member:
-                        if (
-                            event.membership == Membership.JOIN
-                            or event.membership == Membership.INVITE
-                            or event.membership == Membership.KNOCK
-                        ):
-                            newly_joined_or_invited_or_knocked_users.add(
-                                event.state_key
-                            )
-                        else:
-                            prev_content = event.unsigned.get("prev_content", {})
-                            prev_membership = prev_content.get("membership", None)
-                            if prev_membership == Membership.JOIN:
-                                newly_left_users.add(event.state_key)
-
-        newly_left_users -= newly_joined_or_invited_or_knocked_users
+        # 5. Work out which users have joined or left rooms we're in. We use this
+        # to build the device_list part of the sync response in
+        # `_generate_sync_entry_for_device_list`.
+        (
+            newly_joined_or_invited_or_knocked_users,
+            newly_left_users,
+        ) = sync_result_builder.calculate_user_changes()
 
         return (
             set(newly_joined_rooms),
@@ -1585,11 +1605,36 @@ class SyncHandler:
             newly_left_users,
         )
 
+    async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
+        """Retrieve the users ignored by the given user from their global account_data.
+
+        Returns an empty set if
+        - there is no global account_data entry for ignored_users
+        - there is such an entry, but it's not a JSON object.
+        """
+        # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
+        ignored_account_data = (
+            await self.store.get_global_account_data_by_type_for_user(
+                AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
+            )
+        )
+
+        # If there is ignored users account data and it matches the proper type,
+        # then use it.
+        ignored_users: FrozenSet[str] = frozenset()
+        if ignored_account_data:
+            ignored_users_data = ignored_account_data.get("ignored_users", {})
+            if isinstance(ignored_users_data, dict):
+                ignored_users = frozenset(ignored_users_data.keys())
+        return ignored_users
+
     async def _have_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> bool:
         """Returns whether there may be any new events that should be sent down
         the sync. Returns True if there are.
+
+        Does not modify the `sync_result_builder`.
         """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
@@ -1597,12 +1642,13 @@ class SyncHandler:
 
         assert since_token
 
-        # Get a list of membership change events that have happened.
-        rooms_changed = await self.store.get_membership_changes_for_user(
+        # Get a list of membership change events that have happened to the user
+        # requesting the sync.
+        membership_changes = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
-        if rooms_changed:
+        if membership_changes:
             return True
 
         stream_id = since_token.room_key.stream
@@ -1614,7 +1660,25 @@ class SyncHandler:
     async def _get_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
-        """Gets the the changes that have happened since the last sync."""
+        """Determine the changes in rooms to report to the user.
+
+        Ideally, we want to report all events whose stream ordering `s` lies in the
+        range `since_token < s <= now_token`, where the two tokens are read from the
+        sync_result_builder.
+
+        If there are too many events in that range to report, things get complicated.
+        In this situation we return a truncated list of the most recent events, and
+        indicate in the response that there is a "gap" of omitted events. Additionally:
+
+        - we include a "state_delta", to describe the changes in state over the gap,
+        - we include all membership events applying to the user making the request,
+          even those in the gap.
+
+        See the spec for the rationale:
+            https://spec.matrix.org/v1.1/client-server-api/#syncing
+
+        The sync_result_builder is not modified by this function.
+        """
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
@@ -1622,21 +1686,36 @@ class SyncHandler:
 
         assert since_token
 
-        # Get a list of membership change events that have happened.
-        rooms_changed = await self.store.get_membership_changes_for_user(
+        # The spec
+        #     https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
+        # notes that membership events need special consideration:
+        #
+        # > When a sync is limited, the server MUST return membership events for events
+        # > in the gap (between since and the start of the returned timeline), regardless
+        # > as to whether or not they are redundant.
+        #
+        # We fetch such events here, but we only seem to use them for categorising rooms
+        # as newly joined, newly left, invited or knocked.
+        # TODO: we've already called this function and ran this query in
+        #       _have_rooms_changed. We could keep the results in memory to avoid a
+        #       second query, at the cost of more complicated source code.
+        membership_change_events = await self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
-        for event in rooms_changed:
+        for event in membership_change_events:
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
-        newly_joined_rooms = []
-        newly_left_rooms = []
-        room_entries = []
-        invited = []
-        knocked = []
+        newly_joined_rooms: List[str] = []
+        newly_left_rooms: List[str] = []
+        room_entries: List[RoomSyncResultBuilder] = []
+        invited: List[InvitedSyncResult] = []
+        knocked: List[KnockedSyncResult] = []
         for room_id, events in mem_change_events_by_room_id.items():
+            # The body of this loop will add this room to at least one of the five lists
+            # above. Things get messy if you've e.g. joined, left, joined then left the
+            # room all in the same sync period.
             logger.debug(
                 "Membership changes in %s: [%s]",
                 room_id,
@@ -1691,6 +1770,7 @@ class SyncHandler:
 
             if not non_joins:
                 continue
+            last_non_join = non_joins[-1]
 
             # Check if we have left the room. This can either be because we were
             # joined before *or* that we since joined and then left.
@@ -1712,18 +1792,18 @@ class SyncHandler:
                         newly_left_rooms.append(room_id)
 
             # Only bother if we're still currently invited
-            should_invite = non_joins[-1].membership == Membership.INVITE
+            should_invite = last_non_join.membership == Membership.INVITE
             if should_invite:
-                if event.sender not in ignored_users:
-                    invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+                if last_non_join.sender not in ignored_users:
+                    invite_room_sync = InvitedSyncResult(room_id, invite=last_non_join)
                     if invite_room_sync:
                         invited.append(invite_room_sync)
 
             # Only bother if our latest membership in the room is knock (and we haven't
             # been accepted/rejected in the meantime).
-            should_knock = non_joins[-1].membership == Membership.KNOCK
+            should_knock = last_non_join.membership == Membership.KNOCK
             if should_knock:
-                knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1])
+                knock_room_sync = KnockedSyncResult(room_id, knock=last_non_join)
                 if knock_room_sync:
                     knocked.append(knock_room_sync)
 
@@ -1781,7 +1861,9 @@ class SyncHandler:
 
         timeline_limit = sync_config.filter_collection.timeline_limit()
 
-        # Get all events for rooms we're currently joined to.
+        # Get all events since the `from_key` in rooms we're currently joined to.
+        # If there are too many, we get the most recent events only. This leaves
+        # a "gap" in the timeline, as described by the spec for /sync.
         room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
@@ -1842,6 +1924,10 @@ class SyncHandler:
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
+        Like `_get_rooms_changed`, but assumes the `since_token` is `None`.
+
+        This function does not modify the sync_result_builder.
+
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
@@ -1853,16 +1939,9 @@ class SyncHandler:
         now_token = sync_result_builder.now_token
         sync_config = sync_result_builder.sync_config
 
-        membership_list = (
-            Membership.INVITE,
-            Membership.KNOCK,
-            Membership.JOIN,
-            Membership.LEAVE,
-            Membership.BAN,
-        )
-
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
-            user_id=user_id, membership_list=membership_list
+            user_id=user_id,
+            membership_list=Membership.LIST,
         )
 
         room_entries = []
@@ -2212,8 +2291,7 @@ def _calculate_state(
     # to only include membership events for the senders in the timeline.
     # In practice, we can do this by removing them from the p_ids list,
     # which is the list of relevant state we know we have already sent to the client.
-    # see https://github.com/matrix-org/synapse/pull/2970
-    #            /files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
+    # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
         p_ids.difference_update(
@@ -2262,6 +2340,39 @@ class SyncResultBuilder:
     groups: Optional[GroupsSyncResult] = None
     to_device: List[JsonDict] = attr.Factory(list)
 
+    def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
+        """Work out which other users have joined or left rooms we are joined to.
+
+        This data only is only useful for an incremental sync.
+
+        The SyncResultBuilder is not modified by this function.
+        """
+        newly_joined_or_invited_or_knocked_users = set()
+        newly_left_users = set()
+        if self.since_token:
+            for joined_sync in self.joined:
+                it = itertools.chain(
+                    joined_sync.timeline.events, joined_sync.state.values()
+                )
+                for event in it:
+                    if event.type == EventTypes.Member:
+                        if (
+                            event.membership == Membership.JOIN
+                            or event.membership == Membership.INVITE
+                            or event.membership == Membership.KNOCK
+                        ):
+                            newly_joined_or_invited_or_knocked_users.add(
+                                event.state_key
+                            )
+                        else:
+                            prev_content = event.unsigned.get("prev_content", {})
+                            prev_membership = prev_content.get("membership", None)
+                            if prev_membership == Membership.JOIN:
+                                newly_left_users.add(event.state_key)
+
+        newly_left_users -= newly_joined_or_invited_or_knocked_users
+        return newly_joined_or_invited_or_knocked_users, newly_left_users
+
 
 @attr.s(slots=True, auto_attribs=True)
 class RoomSyncResultBuilder: