summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorTravis Ralston <travpc@gmail.com>2021-02-02 07:26:36 -0700
committerTravis Ralston <travpc@gmail.com>2021-02-02 07:26:36 -0700
commit6ec034411cdad6f9f77f7dab7749d426ffd5744b (patch)
treed7935121ad99051eddf5769993bef674e618379e /synapse/handlers
parentAppease the linters (diff)
parentUpdate changelog (diff)
downloadsynapse-6ec034411cdad6f9f77f7dab7749d426ffd5744b.tar.xz
Merge branch 'develop' into travis/fosdem/admin-api-groups github/travis/fosdem/admin-api-groups travis/fosdem/admin-api-groups
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/auth.py38
-rw-r--r--synapse/handlers/cas_handler.py11
-rw-r--r--synapse/handlers/device.py12
-rw-r--r--synapse/handlers/e2e_keys.py223
-rw-r--r--synapse/handlers/e2e_room_keys.py91
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/handlers/groups_local.py83
-rw-r--r--synapse/handlers/identity.py30
-rw-r--r--synapse/handlers/message.py44
-rw-r--r--synapse/handlers/oidc_handler.py66
-rw-r--r--synapse/handlers/register.py7
-rw-r--r--synapse/handlers/room.py9
-rw-r--r--synapse/handlers/room_member.py25
-rw-r--r--synapse/handlers/saml_handler.py7
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py10
-rw-r--r--synapse/handlers/sso.py220
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py39
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py9
23 files changed, 750 insertions, 343 deletions
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
             logger.error(ACME_REGISTER_FAIL_ERROR)
             raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
         logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
         except Exception:
             logger.exception("Failed saving!")
             raise
-
-        return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644
--- a/synapse/handlers/acme_issuing_service.py
+++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
         reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
             we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
         self.certs[server_name] = [o.as_bytes() for o in pem_objects]
         return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
     """Load the ACME account key from a file, creating it if it does not exist.
 
     Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
     """
     # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
     # hardcode the 'client.key' filename
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 0e98db22b3..a19c556437 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
@@ -567,16 +568,6 @@ class AuthHandler(BaseHandler):
                         session.session_id, login_type, result
                     )
             except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
                 errordict = e.error_dict()
@@ -1387,7 +1378,9 @@ class AuthHandler(BaseHandler):
         )
 
         return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+            description=session.description,
+            redirect_url=redirect_url,
+            idp=sso_auth_provider,
         )
 
     async def complete_sso_login(
@@ -1396,6 +1389,7 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1406,6 +1400,8 @@ class AuthHandler(BaseHandler):
                 process.
             extra_attributes: Extra attributes which will be passed to the client
                 during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1414,8 +1410,17 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
         self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
         )
 
     def _complete_sso_login(
@@ -1424,12 +1429,18 @@ class AuthHandler(BaseHandler):
         request: Request,
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
         # Store any extra attributes which will be passed in the login response.
         # Note that this is per-user so it may overwrite a previous value, this
         # is considered OK since the newest SSO attributes should be most valid.
@@ -1467,6 +1478,9 @@ class AuthHandler(BaseHandler):
             display_url=redirect_url_no_params,
             redirect_url=redirect_url,
             server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
         )
         respond_with_html(request, 200, html)
 
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index f3430c6713..bd35d1fb87 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -80,6 +80,11 @@ class CasHandler:
         # user-facing name of this auth provider
         self.idp_name = "CAS"
 
+        # we do not currently support brands/icons for CAS auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
@@ -95,11 +100,7 @@ class CasHandler:
         Returns:
             The URL to use as a "service" parameter.
         """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
     async def _validate_ticket(
         self, ticket: str, service_args: Dict[str, str]
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
 
     @trace
-    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
         Retrieve the given user's devices
 
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
         return devices
 
     @trace
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """ Retrieve the given device
 
         Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
     async def process_cross_signing_key_update(
         self,
         user_id: str,
-        master_key: Optional[Dict[str, Any]],
-        self_signing_key: Optional[Dict[str, Any]],
+        master_key: Optional[JsonDict],
+        self_signing_key: Optional[JsonDict],
     ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
     UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
         )
 
     @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
         """ Handle a device key query from a client
 
         {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
         }
 
         Args:
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
         # separate users by domain.
         # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
         set_tag("remote_key_query", remote_queries)
 
         # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
         results = {}
         if local_query:
             local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
         )
 
         # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
         if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
             for user_id, device_ids in remote_queries.items():
                 if device_ids:
                     query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
         return ret
 
     async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
         """Get cross-signing keys for users from the database
 
         Args:
-            query (Iterable[string]) an iterable of user IDs.  A dict whose keys
+            query: an iterable of user IDs.  A dict whose keys
                 are user IDs satisfies this, so the query format used for
                 query_devices can be used here.
-            from_user_id (str): the user making the query.  This is used when
+            from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
             if "self_signing" in user_info:
                 self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
         return {
             "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
             A map from user_id -> device_id -> device details
         """
         set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
         for user_id, device_ids in query.items():
             # we use UserID.from_string to catch invalid user ids
             if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
         log_kv(results)
         return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
         """ Handle a device key query from a federated server
         """
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
         res = await self.query_local_devices(device_keys_query)
         ret = {"device_keys": res}
 
@@ -397,31 +409,34 @@ class E2eKeysHandler:
         return ret
 
     @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
                 domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
         set_tag("local_key_query", local_query)
         set_tag("remote_key_query", remote_queries)
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
         for user_id, device_keys in results.items():
             for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         @trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
         return {"one_time_keys": json_result, "failures": failures}
 
     @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
         time_now = self.clock.time_msec()
 
@@ -543,8 +560,8 @@ class E2eKeysHandler:
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
         logger.info(
             "Adding one_time_keys %r for device %r for user %r at %d",
             one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
         log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
         await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
         """Upload signing keys for cross-signing
 
         Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
         """
 
         # if a master key is uploaded, then check it.  Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
 
         return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
         """Upload device signatures for cross-signing
 
         Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+            from the user; an exception will be raised if it is malformed.
         Returns:
-            dict: response to be sent back to the client.  The response will have
+            The response to be sent back to the client.  The response will have
                 a "failures" key, which will be a dict mapping users to devices
                 to errors for the signatures that failed.
         Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
 
         return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of the user's own keys.
 
         Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
-            reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -834,19 +855,24 @@ class E2eKeysHandler:
         return signature_list, failures
 
     def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
         """Check signatures of a user's master key made by their devices.
 
         Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
         Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
         Raises:
             SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
 
         return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
         """Process uploaded signatures of other users' keys.  These will be the
         target user's master keys, signed by the uploading user's user-signing
         key.
 
         Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
         Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-            a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
             reasons
 
         Raises:
             SynapseError: if the input is malformed
         """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
         if not signatures:
             return signature_list, failures
 
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
 
     async def _get_e2e_cross_signing_verify_key(
         self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
         """Fetch locally or remotely query for a cross-signing public key.
 
         First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
                 This affects what signatures are fetched.
 
         Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
         Raises:
             NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
         return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
     checking, and ensures that it is signed, if a signature is required.
 
     Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with.  If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with.  If
+            omitted, signatures will not be checked.
     """
     if (
         key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
             )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
     """Check that a signature on a device or cross-signing key is correct and
     matches the copy of the device/key that we have stored.  Throws an
     exception if an error is detected.
 
     Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
     Raises:
         SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
         raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
     if isinstance(e, SynapseError):
         return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
     return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
     old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
     """An item in the signature list as used by upload_signatures_for_device_keys.
     """
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
             iterable=True,
         )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming signing key update from federation. Responsible for
         parsing the EDU and adding to pending updates list.
 
         Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
         """
 
         user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
         await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
         """Actually handle pending updates.
 
         Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
         """
 
         device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
                 # This can happen since we batch updates
                 return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
             logger.info("pending updates: %r", pending_updates)
 
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import (
     Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
     The actual payload of the encrypted keys is completely opaque to the handler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         # Used to lock whenever a client is uploading key data.  This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
         self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
     @trace
-    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> List[JsonDict]:
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.
         See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose keys we're getting
-            version(str): the version ID of the backup we're getting keys from
-            room_id(string): room ID to get keys for, for None to get keys for all rooms
-            session_id(string): session ID to get keys for, for None to get keys for all
+            user_id: the user whose keys we're getting
+            version: the version ID of the backup we're getting keys from
+            room_id: room ID to get keys for, for None to get keys for all rooms
+            session_id: session ID to get keys for, for None to get keys for all
                 sessions
         Raises:
             NotFoundError: if the backup version does not exist
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """
 
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
             return results
 
     @trace
-    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> JsonDict:
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.
         See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
 
         Args:
-            user_id(str): the user whose backup we're deleting
-            version(str): the version ID of the backup we're deleting
-            room_id(string): room ID to delete keys for, for None to delete keys for all
+            user_id: the user whose backup we're deleting
+            version: the version ID of the backup we're deleting
+            room_id: room ID to delete keys for, for None to delete keys for all
                 rooms
-            session_id(string): session ID to delete keys for, for None to delete keys
+            session_id: session ID to delete keys for, for None to delete keys
                 for all sessions
         Raises:
             NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @trace
-    async def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(
+        self, user_id: str, version: str, room_keys: JsonDict
+    ) -> JsonDict:
         """Bulk upload a list of room keys into a given backup version, asserting
         that the given version is the current backup version.  room_keys are merged
         into the current backup as described in RoomKeysServlet.on_PUT().
 
         Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_keys(dict): a nested dict describing the room_keys we're setting:
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_keys: a nested dict describing the room_keys we're setting:
 
         {
             "rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
             return {"etag": str(version_etag), "count": count}
 
     @staticmethod
-    def _should_replace_room_key(current_room_key, room_key):
+    def _should_replace_room_key(
+        current_room_key: Optional[JsonDict], room_key: JsonDict
+    ) -> bool:
         """
         Determine whether to replace a given current_room_key (if any)
         with a newly uploaded room_key backup
 
         Args:
-            current_room_key (dict): Optional, the current room_key dict if any
-            room_key (dict): The new room_key dict which may or may not be fit to
+            current_room_key: Optional, the current room_key dict if any
+            room_key : The new room_key dict which may or may not be fit to
                 replace the current_room_key
 
         Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
         return True
 
     @trace
-    async def create_version(self, user_id, version_info):
+    async def create_version(self, user_id: str, version_info: JsonDict) -> str:
         """Create a new backup version.  This automatically becomes the new
         backup version for the user's keys; previous backups will no longer be
         writeable to.
 
         Args:
-            user_id(str): the user whose backup version we're creating
-            version_info(dict): metadata about the new version being created
+            user_id: the user whose backup version we're creating
+            version_info: metadata about the new version being created
 
         {
             "algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
         }
 
         Returns:
-            A deferred of a string that gives the new version number.
+            The new version number.
         """
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
             )
             return new_version
 
-    async def get_version_info(self, user_id, version=None):
+    async def get_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
         """Get the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're querying
-            version(str): Optional; if None gives the most recent version
+            user_id: the user whose current backup version we're querying
+            version: Optional; if None gives the most recent version
                 otherwise a historical one.
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of a info dict that gives the info about the new version.
+            A info dict that gives the info about the new version.
 
         {
             "version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
             return res
 
     @trace
-    async def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
                     raise
 
     @trace
-    async def update_version(self, user_id, version, version_info):
+    async def update_version(
+        self, user_id: str, version: str, version_info: JsonDict
+    ) -> JsonDict:
         """Update the info about a given version of the user's backup
 
         Args:
-            user_id(str): the user whose current backup version we're updating
-            version(str): the backup version we're updating
-            version_info(dict): the new information about the backup
+            user_id: the user whose current backup version we're updating
+            version: the backup version we're updating
+            version_info: the new information about the backup
         Raises:
             NotFoundError: if the requested backup version doesn't exist
         Returns:
-            A deferred of an empty dict.
+            An empty dict.
         """
         if "version" not in version_info:
             version_info["version"] = version
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..dbdfd56ff5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler):
         if event.state_key == self._server_notices_mxid:
             raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
+        member_handler = self.hs.get_room_member_handler()
+        member_handler.ratelimit_invite(event.room_id, event.state_key)
+
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
@@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index a2f16f77df..73eb7db633 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
     get_group_role = _create_rerouter("get_group_role")
     get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group summary for a group.
 
         If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get users in a group
         """
         if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                 group_id, requester_user_id
             )
-            return res
 
         group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
 
         return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
         group_ids = await self.store.get_joined_groups(user_id)
         return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
         if self.hs.is_mine_id(user_id):
             result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
             # TODO: Verify attestations
             return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
         local_users = set()
 
         for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
             raise SynapseError(400, "Some user_ids are not local")
 
         results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
         for destination, dest_user_ids in destinations.items():
             try:
                 r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
     set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Create a group
         """
 
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
             local_attestation = None
             remote_attestation = None
         else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
         is_publicised = content.get("publicise", False)
         token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def join_group(self, group_id, user_id, content):
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Request to join a group
         """
         if self.is_mine_id(group_id):
@@ -391,7 +384,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def accept_invite(self, group_id, user_id, content):
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Accept an invite to a group
         """
         if self.is_mine_id(group_id):
@@ -436,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
         """Invite a user to a group
         """
         content = {"requester_user_id": requester_user_id, "config": config}
@@ -460,7 +457,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def on_invite(self, group_id, user_id, content):
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """One of our users were invited to a group
         """
         # TODO: Support auto join and rejection
@@ -491,8 +490,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
         return {"state": "invite", "user_profile": user_profile}
 
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from a group
         """
         if user_id == requester_user_id:
@@ -525,7 +524,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
         return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
         """One of our users was removed/kicked from a group
         """
         # TODO: Check if user in group
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index c05036ad1f..4f7137539b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
 
         self._web_client_location = hs.config.invite_client_location
 
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self, request: SynapseRequest, medium: str, address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
@@ -476,8 +504,6 @@ class IdentityHandler(BaseHandler):
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        assert self.hs.config.public_baseurl
-
         # we need to tell the client to send the token back to us, since it doesn't
         # otherwise know where to send it, so add submit_url response parameter
         # (see also MSC2078)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..a15336bf00 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -174,7 +174,7 @@ class MessageHandler:
                 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
             visible_events = await filter_events_for_client(
-                self.storage, user_id, last_events, filter_send_to_client=False
+                self.storage, user_id, last_events, filter_send_to_client=False,
             )
 
             event = last_events[0]
@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index ba686d74b2..71008ec50d 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -102,7 +102,7 @@ class OidcHandler:
                 ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
@@ -271,6 +271,12 @@ class OidcProvider:
         # user-facing name of this auth provider
         self.idp_name = provider.idp_name
 
+        # MXC URI for icon for this auth provider
+        self.idp_icon = provider.idp_icon
+
+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
@@ -637,7 +643,7 @@ class OidcProvider:
 
           - ``client_id``: the client ID set in ``oidc_config.client_id``
           - ``response_type``: ``code``
-          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+          - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
           - ``scope``: the list of scopes set in ``oidc_config.scopes``
           - ``state``: a random string
           - ``nonce``: a random string
@@ -678,7 +684,7 @@ class OidcProvider:
         request.addCookie(
             SESSION_COOKIE_NAME,
             cookie,
-            path="/_synapse/oidc",
+            path="/_synapse/client/oidc",
             max_age="3600",
             httpOnly=True,
             sameSite="lax",
@@ -699,7 +705,7 @@ class OidcProvider:
     async def handle_oidc_callback(
         self, request: SynapseRequest, session_data: "OidcSessionData", code: str
     ) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback
 
         By this time we have already validated the session on the synapse side, and
         now need to do the provider-specific operations. This includes:
@@ -1053,7 +1059,8 @@ class OidcSessionData:
 
 
 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")
 
@@ -1132,11 +1139,12 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
     subject_claim = attr.ib(type=str)
     localpart_template = attr.ib(type=Optional[Template])
     display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
     extra_attributes = attr.ib(type=Dict[str, Template])
 
 
@@ -1153,23 +1161,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
 
-        localpart_template = None  # type: Optional[Template]
-        if "localpart_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                localpart_template = env.from_string(config["localpart_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["localpart_template"]
-                ) from e
+                raise ConfigError("invalid jinja template", path=[option_name]) from e
 
-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
-            try:
-                display_name_template = env.from_string(config["display_name_template"])
-            except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["display_name_template"]
-                ) from e
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")
 
         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:
@@ -1189,6 +1191,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )
 
@@ -1210,16 +1213,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             # a usable mxid.
             localpart += str(failures) if failures else ""
 
-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()
 
-            if display_name == "":
-                display_name = None
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None
 
-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
+
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a2cf0f6f3e..49b085269b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -14,8 +14,9 @@
 # limitations under the License.
 
 """Contains functions for registering clients."""
+
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler):
         user_type: Optional[str] = None,
         default_display_name: Optional[str] = None,
         address: Optional[str] = None,
-        bind_emails: List[str] = [],
+        bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
     ) -> str:
@@ -693,6 +694,8 @@ class RegistrationHandler(BaseHandler):
             access_token: The access token of the newly logged in device, or
                 None if `inhibit_login` enabled.
         """
+        # TODO: 3pid registration can actually happen on the workers. Consider
+        # refactoring it.
         if self.hs.config.worker_app:
             await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 3bece6d668..07b2187eb1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,7 +38,6 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
-from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
@@ -55,6 +54,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import parse_and_validate_server_name
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._invite_burst_count = (
+            hs.config.ratelimiting.rc_invites_per_room.burst_count
+        )
+
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:
@@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
             invite_3pid_list = []
             invite_list = []
 
+        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+            raise SynapseError(400, "Cannot invite so many users at once")
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)
 
         power_level_content_override = config.get("power_level_content_override")
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..d335da6f19 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
+        self._invites_per_room_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+        )
+        self._invites_per_user_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+        )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.
@@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
+    def ratelimit_invite(self, room_id: str, invitee_user_id: str):
+        """Ratelimit invites by room and by target user.
+        """
+        self._invites_per_room_limiter.ratelimit(room_id)
+        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise SynapseError(403, "This room has been blocked on this server")
 
         if effective_membership_state == Membership.INVITE:
+            target_id = target.to_string()
+            if ratelimit:
+                self.ratelimit_invite(room_id, target_id)
+
             # block any attempts to invite the server notices mxid
-            if target.to_string() == self._server_notices_mxid:
+            if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
             block_invite = False
@@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     block_invite = True
 
                 if not await self.spam_checker.user_may_invite(
-                    requester.user.to_string(), target.to_string(), room_id
+                    requester.user.to_string(), target_id, room_id
                 ):
                     logger.info("Blocking invite due to spam checker")
                     block_invite = True
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a8376543c9..e88fd59749 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -78,6 +78,11 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"
 
+        # we do not currently support icons/brands for SAML auth, but this is required by
+        # the SsoIdentityProvider protocol type.
+        self.idp_icon = None
+        self.idp_brand = None
+
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
@@ -128,7 +133,7 @@ class SamlHandler(BaseHandler):
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_matrix/saml2/authn_response
+        """Handle an incoming request to /_synapse/client/saml2/authn_response
 
         Args:
             request: the incoming request from the browser. We'll
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@
 
 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
 
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
+from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
 
         return historical_room_ids
 
-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.
 
         Returns:
             dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
 
         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]
 
         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)
 
         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0
 
             pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
 
         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())
 
-            state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong
 
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
 
         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )
 
             rooms_cat_res["state"] = s
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..84af2dde7e 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
-        self._password_policy_handler = hs.get_password_policy_handler()
 
     async def set_password(
         self,
@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index dcc85e9871..b450668f1c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -14,21 +14,31 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+)
 from urllib.parse import urlencode
 
 import attr
 from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
+from twisted.web.iweb import IRequest
 
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 from synapse.util.stringutils import random_string
 
@@ -75,6 +85,16 @@ class SsoIdentityProvider(Protocol):
     def idp_name(self) -> str:
         """User-facing name for this provider"""
 
+    @property
+    def idp_icon(self) -> Optional[str]:
+        """Optional MXC URI for user-facing icon"""
+        return None
+
+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
@@ -104,7 +124,7 @@ class UserAttributes:
     # enter one.
     localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))
 
 
 @attr.s(slots=True)
@@ -119,7 +139,7 @@ class UsernameMappingSession:
 
     # attributes returned by the ID mapper
     display_name = attr.ib(type=Optional[str])
-    emails = attr.ib(type=List[str])
+    emails = attr.ib(type=Collection[str])
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -131,6 +151,12 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)
 
+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
+
 
 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@@ -165,6 +191,8 @@ class SsoHandler:
         # map from idp_id to SsoIdentityProvider
         self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
 
+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
     def register_identity_provider(self, p: SsoIdentityProvider):
         p_id = p.idp_id
         assert p_id not in self._identity_providers
@@ -230,7 +258,10 @@ class SsoHandler:
         respond_with_html(request, code, html)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes,
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
     ) -> str:
         """Handle a request to /login/sso/redirect
 
@@ -238,6 +269,7 @@ class SsoHandler:
             request: incoming HTTP request
             client_redirect_url: the URL that we should redirect the
                 client to after login.
+            idp_id: optional identity provider chosen by the client
 
         Returns:
              the URI to redirect to
@@ -247,10 +279,19 @@ class SsoHandler:
                 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
             )
 
+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
         # if we only have one auth provider, redirect to it directly
-        if len(self._identity_providers) == 1:
-            ap = next(iter(self._identity_providers.values()))
-            return await ap.handle_redirect_request(request, client_redirect_url)
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)
 
         # otherwise, redirect to the IDP picker
         return "/_synapse/client/pick_idp?" + urlencode(
@@ -364,6 +405,8 @@ class SsoHandler:
                 to an additional page. (e.g. to prompt for more information)
 
         """
+        new_user = False
+
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.
@@ -404,9 +447,14 @@ class SsoHandler:
                     get_request_user_agent(request),
                     request.getClientIP(),
                 )
+                new_user = True
 
         await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url, extra_login_attributes
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
         )
 
     async def _call_attribute_mapper(
@@ -496,7 +544,7 @@ class SsoHandler:
         logger.info("Recorded registration session id %s", session_id)
 
         # Set the cookie and redirect to the username picker
-        e = RedirectException(b"/_synapse/client/pick_username")
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
         e.cookies.append(
             b"%s=%s; path=/"
             % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
@@ -624,6 +672,25 @@ class SsoHandler:
         )
         respond_with_html(request, 200, html)
 
+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
     async def check_username_availability(
         self, localpart: str, session_id: str,
     ) -> bool:
@@ -640,12 +707,7 @@ class SsoHandler:
 
         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        self.get_mapping_session(session_id)
 
         logger.info(
             "[session %s] Checking for availability of username %s",
@@ -662,7 +724,12 @@ class SsoHandler:
         return not user_infos
 
     async def handle_submit_username_request(
-        self, request: SynapseRequest, localpart: str, session_id: str
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint
 
@@ -672,21 +739,90 @@ class SsoHandler:
             request: HTTP request
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
         """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        session = self.get_mapping_session(session_id)
+
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set conjunction, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
 
-        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        session = self.get_mapping_session(session_id)
+        session.terms_accepted_version = terms_version
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
+
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )
 
         attributes = UserAttributes(
-            localpart=localpart,
-            display_name=session.display_name,
-            emails=session.emails,
+            localpart=session.chosen_localpart, emails=session.emails_to_use,
         )
 
+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(
@@ -697,7 +833,12 @@ class SsoHandler:
             request.getClientIP(),
         )
 
-        logger.info("[session %s] Registered userid %s", session_id, user_id)
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )
 
         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]
@@ -710,11 +851,21 @@ class SsoHandler:
             path=b"/",
         )
 
+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
         await self._auth_handler.complete_sso_login(
             user_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
+            new_user=True,
         )
 
     def _expire_old_sessions(self):
@@ -728,3 +879,14 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
 
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..d261d7cd4e 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled
 
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
@@ -56,7 +62,7 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:
@@ -72,7 +78,7 @@ class StatsHandler:
 
         run_as_background_process("stats.notify_new_event", process)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()
@@ -110,10 +116,10 @@ class StatsHandler:
             )
 
             for room_id, fields in room_count.items():
-                room_deltas.setdefault(room_id, {}).update(fields)
+                room_deltas.setdefault(room_id, Counter()).update(fields)
 
             for user_id, fields in user_count.items():
-                user_deltas.setdefault(user_id, {}).update(fields)
+                user_deltas.setdefault(user_id, Counter()).update(fields)
 
             logger.debug("room_deltas: %s", room_deltas)
             logger.debug("user_deltas: %s", user_deltas)
@@ -131,19 +137,20 @@ class StatsHandler:
 
             self.pos = max_pos
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process
 
         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """
 
-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
 
-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]
 
         for delta in deltas:
             typ = delta["type"]
@@ -173,7 +180,7 @@ class StatsHandler:
                 )
                 continue
 
-            event_content = {}
+            event_content = {}  # type: JsonDict
 
             sender = None
             if event_id is not None:
@@ -257,13 +264,13 @@ class StatsHandler:
                     )
 
                     if has_changed_joinedness:
-                        delta = +1 if membership == Membership.JOIN else -1
+                        membership_delta = +1 if membership == Membership.JOIN else -1
 
                         user_to_stats_deltas.setdefault(user_id, Counter())[
                             "joined_rooms"
-                        ] += delta
+                        ] += membership_delta
 
-                        room_stats_delta["local_users_in_room"] += delta
+                        room_stats_delta["local_users_in_room"] += membership_delta
 
             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
             )
 
         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]
 
-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)
 
-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)
 
-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
 
-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])
 
-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
 
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         """Should be called whenever we receive updates for typing stream.
         """
 
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
 
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial
 
 
 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]
 
         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )
 
-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)
 
         if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return
 
-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         if was_present:
             # No point sending another notification
-            return None
+            return
 
         self._push_update(member=member, typing=True)
 
-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()
 
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._stopped_typing(member)
 
-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)
 
-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return
 
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
         self._push_update(member=member, typing=False)
 
-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         self._push_update_local(member=member, typing=typing)
 
-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]
 
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])
 
-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]
 
         if changed_rooms is None:
             changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
 
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")
 
 
 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler
 
-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
 
             return (events, handler._latest_room_serial)
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
 
-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
-
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
 
                     if change:  # The user joined
                         event = await self.store.get_event(event_id, allow_none=True)
+                        # It isn't expected for this event to not exist, but we
+                        # don't want the entire background process to break.
+                        if event is None:
+                            continue
+
                         profile = ProfileInfo(
                             avatar_url=event.content.get("avatar_url"),
                             display_name=event.content.get("displayname"),