Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py | 3
-rw-r--r--  synapse/appservice/api.py | 1
-rw-r--r--  synapse/config/__main__.py | 1
-rw-r--r--  synapse/config/emailconfig.py | 3
-rw-r--r--  synapse/config/jwt_config.py | 35
-rw-r--r--  synapse/federation/transport/server.py | 6
-rw-r--r--  synapse/handlers/appservice.py | 74
-rw-r--r--  synapse/handlers/auth.py | 8
-rw-r--r--  synapse/handlers/cas_handler.py | 3
-rw-r--r--  synapse/handlers/federation.py | 166
-rw-r--r--  synapse/handlers/typing.py | 3
-rw-r--r--  synapse/http/additional_resource.py | 19
-rw-r--r--  synapse/http/server.py | 365
-rw-r--r--  synapse/logging/opentracing.py | 70
-rw-r--r--  synapse/push/httppusher.py | 17
-rw-r--r--  synapse/push/push_rule_evaluator.py | 31
-rw-r--r--  synapse/python_dependencies.py | 2
-rw-r--r--  synapse/replication/http/__init__.py | 3
-rw-r--r--  synapse/replication/http/_base.py | 11
-rw-r--r--  synapse/replication/slave/storage/account_data.py | 5
-rw-r--r--  synapse/replication/slave/storage/deviceinbox.py | 3
-rw-r--r--  synapse/replication/slave/storage/groups.py | 3
-rw-r--r--  synapse/replication/slave/storage/presence.py | 3
-rw-r--r--  synapse/replication/slave/storage/push_rule.py | 3
-rw-r--r--  synapse/replication/slave/storage/pushers.py | 3
-rw-r--r--  synapse/replication/slave/storage/receipts.py | 11
-rw-r--r--  synapse/replication/slave/storage/room.py | 3
-rw-r--r--  synapse/replication/tcp/client.py | 2
-rw-r--r--  synapse/replication/tcp/handler.py | 4
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 56
-rw-r--r--  synapse/replication/tcp/streams/events.py | 2
-rw-r--r--  synapse/rest/client/v1/login.py | 60
-rw-r--r--  synapse/rest/client/v2_alpha/account.py | 40
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 22
-rw-r--r--  synapse/rest/consent/consent_resource.py | 10
-rw-r--r--  synapse/rest/key/v2/remote_key_resource.py | 12
-rw-r--r--  synapse/rest/media/v1/config_resource.py | 14
-rw-r--r--  synapse/rest/media/v1/download_resource.py | 12
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 10
-rw-r--r--  synapse/rest/media/v1/thumbnail_resource.py | 10
-rw-r--r--  synapse/rest/media/v1/thumbnailer.py | 3
-rw-r--r--  synapse/rest/media/v1/upload_resource.py | 14
-rw-r--r--  synapse/rest/oidc/callback_resource.py | 7
-rw-r--r--  synapse/rest/saml2/response_resource.py | 4
-rw-r--r--  synapse/secrets.py | 3
-rw-r--r--  synapse/storage/data_stores/main/cache.py | 44
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py | 54
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 70
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 65
-rw-r--r--  synapse/storage/data_stores/main/events.py | 101
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 11
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 52
-rw-r--r--  synapse/storage/data_stores/main/purge_events.py | 1
-rw-r--r--  synapse/storage/data_stores/main/pusher.py | 108
-rw-r--r--  synapse/storage/data_stores/main/room.py | 41
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 45
-rw-r--r--  synapse/storage/data_stores/main/ui_auth.py | 2
-rw-r--r--  synapse/storage/engines/postgres.py | 2
-rw-r--r--  synapse/storage/types.py | 2
-rw-r--r--  synapse/types.py | 2
-rw-r--r--  synapse/util/threepids.py | 23
61 files changed, 957 insertions, 806 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 06ba6604f3..cb22508f4d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Optional
 
@@ -22,7 +21,6 @@ from netaddr import IPAddress
 from twisted.internet import defer
 from twisted.web.server import Request
 
-import synapse.logging.opentracing as opentracing
 import synapse.types
 from synapse import event_auth
 from synapse.api.auth_blocking import AuthBlocking
@@ -35,6 +33,7 @@ from synapse.api.errors import (
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.logging import opentracing as opentracing
 from synapse.types import StateMap, UserID
 from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index da9a5e86d4..f92bfb420b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
         if service.url is None:
             return False
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
-        response = None
         try:
             response = yield self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index fca35b008c..65043d5b5b 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -16,6 +16,7 @@ from synapse.config._base import ConfigError
 
 if __name__ == "__main__":
     import sys
+
     from synapse.config.homeserver import HomeServerConfig
 
     action = sys.argv[1]
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ca61214454..df08bcd1bc 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from __future__ import print_function
 
 # This file can't be called email.py because if it is, we cannot:
@@ -145,8 +144,8 @@ class EmailConfig(Config):
             or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
         ):
             # make sure we can import the required deps
-            import jinja2
             import bleach
+            import jinja2
 
             # prevent unused warnings
             jinja2
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index a568726985..fce96b4acf 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -45,10 +45,37 @@ class JWTConfig(Config):
 
     def generate_config_section(self, **kwargs):
         return """\
-        # The JWT needs to contain a globally unique "sub" (subject) claim.
+        # JSON web token integration. The following settings can be used to
+        # make Synapse use JSON web tokens for authentication, instead of its
+        # internal password database.
+        #
+        # Each JSON Web Token needs to contain a "sub" (subject) claim, which is
+        # used as the localpart of the mxid.
+        #
+        # Note that this is a non-standard login type and client support is
+        # expected to be non-existent.
+        #
+        # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
         #
         #jwt_config:
-        #   enabled: true
-        #   secret: "a secret"
-        #   algorithm: "HS256"
+            # Uncomment the following to enable authorization using JSON web
+            # tokens. Defaults to false.
+            #
+            #enabled: true
+
+            # This is either the private shared secret or the public key used to
+            # decode the contents of the JSON web token.
+            #
+            # Required if 'enabled' is true.
+            #
+            #secret: "provided-by-your-issuer"
+
+            # The algorithm used to sign the JSON web token.
+            #
+            # Supported algorithms are listed at
+            # https://pyjwt.readthedocs.io/en/latest/algorithms.html
+            #
+            # Required if 'enabled' is true.
+            #
+            #algorithm: "provided-by-your-issuer"
         """
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..bfb7831a02 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -361,11 +361,7 @@ class BaseFederationServlet(object):
                 continue
 
             server.register_paths(
-                method,
-                (pattern,),
-                self._wrap(code),
-                self.__class__.__name__,
-                trace=False,
+                method, (pattern,), self._wrap(code), self.__class__.__name__,
             )
 
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 904c96eeec..92d4c6e16c 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
         self.current_max = 0
         self.is_processing = False
 
-    @defer.inlineCallbacks
-    def notify_interested_services(self, current_id):
+    async def notify_interested_services(self, current_id):
         """Notifies (pushes) all application services interested in this event.
 
         Pushing is done asynchronously, so this method won't block for any
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
                     (
                         upper_bound,
                         events,
-                    ) = yield self.store.get_new_events_for_appservice(
+                    ) = await self.store.get_new_events_for_appservice(
                         self.current_max, limit
                     )
 
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
-                    @defer.inlineCallbacks
-                    def handle_event(event):
+                    async def handle_event(event):
                         # Gather interested services
-                        services = yield self._get_services_for_event(event)
+                        services = await self._get_services_for_event(event)
                         if len(services) == 0:
                             return  # no services need notifying
 
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
                         # query API for all services which match that user regex.
                         # This needs to block as these user queries need to be
                         # made BEFORE pushing the event.
-                        yield self._check_user_exists(event.sender)
+                        await self._check_user_exists(event.sender)
                         if event.type == EventTypes.Member:
-                            yield self._check_user_exists(event.state_key)
+                            await self._check_user_exists(event.state_key)
 
                         if not self.started_scheduler:
 
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
                             self.scheduler.submit_event_for_as(service, event)
 
                         now = self.clock.time_msec()
-                        ts = yield self.store.get_received_ts(event.event_id)
+                        ts = await self.store.get_received_ts(event.event_id)
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "appservice_sender"
                         ).observe((now - ts) / 1000)
 
-                    @defer.inlineCallbacks
-                    def handle_room_events(events):
+                    async def handle_room_events(events):
                         for event in events:
-                            yield handle_event(event)
+                            await handle_event(event)
 
-                    yield make_deferred_yieldable(
+                    await make_deferred_yieldable(
                         defer.gatherResults(
                             [
                                 run_in_background(handle_room_events, evs)
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
                         )
                     )
 
-                    yield self.store.set_appservice_last_pos(upper_bound)
+                    await self.store.set_appservice_last_pos(upper_bound)
 
                     now = self.clock.time_msec()
-                    ts = yield self.store.get_received_ts(events[-1].event_id)
+                    ts = await self.store.get_received_ts(events[-1].event_id)
 
                     synapse.metrics.event_processing_positions.labels(
                         "appservice_sender"
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
             finally:
                 self.is_processing = False
 
-    @defer.inlineCallbacks
-    def query_user_exists(self, user_id):
+    async def query_user_exists(self, user_id):
         """Check if any application service knows this user_id exists.
 
         Args:
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
         Returns:
             True if this user exists on at least one application service.
         """
-        user_query_services = yield self._get_services_for_user(user_id=user_id)
+        user_query_services = self._get_services_for_user(user_id=user_id)
         for user_service in user_query_services:
-            is_known_user = yield self.appservice_api.query_user(user_service, user_id)
+            is_known_user = await self.appservice_api.query_user(user_service, user_id)
             if is_known_user:
                 return True
         return False
 
-    @defer.inlineCallbacks
-    def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(self, room_alias):
         """Check if an application service knows this room alias exists.
 
         Args:
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
             s for s in services if (s.is_interested_in_alias(room_alias_str))
         ]
         for alias_service in alias_query_services:
-            is_known_alias = yield self.appservice_api.query_alias(
+            is_known_alias = await self.appservice_api.query_alias(
                 alias_service, room_alias_str
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = yield self.store.get_association_from_room_alias(room_alias)
+                result = await self.store.get_association_from_room_alias(room_alias)
                 return result
 
-    @defer.inlineCallbacks
-    def query_3pe(self, kind, protocol, fields):
-        services = yield self._get_services_for_3pn(protocol)
+    async def query_3pe(self, kind, protocol, fields):
+        services = self._get_services_for_3pn(protocol)
 
-        results = yield make_deferred_yieldable(
+        results = await make_deferred_yieldable(
             defer.DeferredList(
                 [
                     run_in_background(
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(self, only_protocol=None):
         services = self.store.get_app_services()
         protocols = {}
 
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
                 if p not in protocols:
                     protocols[p] = []
 
-                info = yield self.appservice_api.get_3pe_protocol(s, p)
+                info = await self.appservice_api.get_3pe_protocol(s, p)
 
                 if info is not None:
                     protocols[p].append(info)
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
 
         return protocols
 
-    @defer.inlineCallbacks
-    def _get_services_for_event(self, event):
+    async def _get_services_for_event(self, event):
         """Retrieve a list of application services interested in this event.
 
         Args:
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
         # inside of a list comprehension anymore.
         interested_list = []
         for s in services:
-            if (yield s.is_interested(event, self.store)):
+            if await s.is_interested(event, self.store):
                 interested_list.append(s)
 
         return interested_list
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
     def _get_services_for_user(self, user_id):
         services = self.store.get_app_services()
         interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return defer.succeed(interested_list)
+        return interested_list
 
     def _get_services_for_3pn(self, protocol):
         services = self.store.get_app_services()
         interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return defer.succeed(interested_list)
+        return interested_list
 
-    @defer.inlineCallbacks
-    def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id):
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
             return False
 
-        user_info = yield self.store.get_user_by_id(user_id)
+        user_info = await self.store.get_user_by_id(user_id)
         if user_info:
             return False
 
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    @defer.inlineCallbacks
-    def _check_user_exists(self, user_id):
-        unknown_user = yield self._is_unknown_user(user_id)
+    async def _check_user_exists(self, user_id):
+        unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = yield self.query_user_exists(user_id)
+            exists = await self.query_user_exists(user_id)
             return exists
         return True
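
The change running through this file is mechanical: each @defer.inlineCallbacks
generator becomes a native coroutine with `yield` replaced by `await`, and
purely synchronous helpers such as _get_services_for_user drop their
defer.succeed() wrapper and return plain values. A self-contained sketch of the
pattern, with illustrative names:

    from twisted.internet import defer

    # Before: a Twisted inlineCallbacks generator.
    @defer.inlineCallbacks
    def get_thing_old(store):
        thing = yield store.fetch_thing()
        return thing

    # After: a native coroutine; callers switch from `yield` to `await`.
    async def get_thing(store):
        thing = await store.fetch_thing()
        return thing
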
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c3f86e7414..a162392e4c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import time
 import unicodedata
@@ -24,7 +23,6 @@ import attr
 import bcrypt  # type: ignore[import]
 import pymacaroons
 
-import synapse.util.stringutils as stringutils
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     AuthError,
@@ -45,6 +43,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
 from synapse.push.mailer import load_jinja2_templates
 from synapse.types import Requester, UserID
+from synapse.util import stringutils as stringutils
+from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
 
@@ -928,7 +928,7 @@ class AuthHandler(BaseHandler):
         # for the presence of an email address during password reset was
         # case sensitive).
         if medium == "email":
-            address = address.lower()
+            address = canonicalise_email(address)
 
         await self.store.user_add_threepid(
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
@@ -956,7 +956,7 @@ class AuthHandler(BaseHandler):
 
         # 'Canonicalise' email addresses as per above
         if medium == "email":
-            address = address.lower()
+            address = canonicalise_email(address)
 
         identity_handler = self.hs.get_handlers().identity_handler
         result = await identity_handler.try_unbind_threepid(
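
The switch from address.lower() to canonicalise_email moves email
normalisation into a shared helper (synapse/util/threepids.py, listed in the
diffstat but not shown here). As a hedged sketch of what such a helper
plausibly does - validate, then lowercase - with the exact behaviour deferred
to the real module:

    def canonicalise_email(address: str) -> str:
        """Validate an email address and return a canonical, lowercased form.

        Raises ValueError if the address cannot be parsed.
        """
        address = address.strip()
        parts = address.split("@")
        if len(parts) != 2:
            raise ValueError("Unable to parse email address")
        return "%s@%s" % (parts[0].casefold(), parts[1].casefold())
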
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 76f213723a..d79ffefdb5 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -12,11 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import urllib
-import xml.etree.ElementTree as ET
 from typing import Dict, Optional, Tuple
+from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4dbd8e1d98..b5aaa244dd 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,8 +19,9 @@
 
 import itertools
 import logging
+from collections.abc import Container
 from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
                 # device and recognize the algorithm then we can work out the
                 # exact key to expect. Otherwise check it matches any key we
                 # have for that device.
+
+                current_keys = []  # type: Container[str]
+
                 if device:
                     keys = device.get("keys", {}).get("keys", {})
 
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
                         current_keys = keys.values()
                 elif device_id:
                     # We don't have any keys for the device ID.
-                    current_keys = []
+                    pass
                 else:
                     # The event didn't include a device ID, so we just look for
                     # keys across all devices.
-                    current_keys = (
+                    current_keys = [
                         key
                         for device in cached_devices
                         for key in device.get("keys", {}).get("keys", {}).values()
-                    )
+                    ]
 
                 # We now check that the sender key matches (one of) the expected
                 # keys.
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
                 if e_type == EventTypes.Member and event.membership == Membership.JOIN
             ]
 
-            joined_domains = {}
+            joined_domains = {}  # type: Dict[str, int]
             for u, d in joined_users:
                 try:
                     dom = get_domain_from_id(u)
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
         try:
             # Try the host we successfully got a response to /make_join/
             # request first.
+            host_list = list(target_hosts)
             try:
-                target_hosts.remove(origin)
-                target_hosts.insert(0, origin)
+                host_list.remove(origin)
+                host_list.insert(0, origin)
             except ValueError:
                 pass
 
             ret = await self.federation_client.send_join(
-                target_hosts, event, room_version_obj
+                host_list, event, room_version_obj
             )
 
             origin = ret["origin"]
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
 
         # Try the host that we successfully called /make_leave/ on first for
         # the /send_leave/ request.
+        host_list = list(target_hosts)
         try:
-            target_hosts.remove(origin)
-            target_hosts.insert(0, origin)
+            host_list.remove(origin)
+            host_list.insert(0, origin)
         except ValueError:
             pass
 
-        await self.federation_client.send_leave(target_hosts, event)
+        await self.federation_client.send_leave(host_list, event)
 
         context = await self.state_handler.compute_event_context(event)
         stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
         user_id: str,
         membership: str,
         content: JsonDict = {},
-        params: Optional[Dict[str, str]] = None,
+        params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
     ) -> Tuple[str, EventBase, RoomVersion]:
         (
             origin,
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
             auth_events_ids = await self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
         # For new (non-backfilled and non-outlier) events we check if the event
         # passes auth based on the current state. If it doesn't then we
         # "soft-fail" the event.
-        do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
-        if do_soft_fail_check:
-            extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
-
-            extrem_ids = set(extrem_ids)
-            prev_event_ids = set(event.prev_event_ids())
-
-            if extrem_ids == prev_event_ids:
-                # If they're the same then the current state is the same as the
-                # state at the event, so no point rechecking auth for soft fail.
-                do_soft_fail_check = False
-
-        if do_soft_fail_check:
-            room_version = await self.store.get_room_version_id(event.room_id)
-            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
-            # Calculate the "current state".
-            if state is not None:
-                # If we're explicitly given the state then we won't have all the
-                # prev events, and so we have a gap in the graph. In this case
-                # we want to be a little careful as we might have been down for
-                # a while and have an incorrect view of the current state,
-                # however we still want to do checks as gaps are easy to
-                # maliciously manufacture.
-                #
-                # So we use a "current state" that is actually a state
-                # resolution across the current forward extremities and the
-                # given state at the event. This should correctly handle cases
-                # like bans, especially with state res v2.
+        if backfilled or event.internal_metadata.is_outlier():
+            return
 
-                state_sets = await self.state_store.get_state_groups(
-                    event.room_id, extrem_ids
-                )
-                state_sets = list(state_sets.values())
-                state_sets.append(state)
-                current_state_ids = await self.state_handler.resolve_events(
-                    room_version, state_sets, event
-                )
-                current_state_ids = {
-                    k: e.event_id for k, e in current_state_ids.items()
-                }
-            else:
-                current_state_ids = await self.state_handler.get_current_state_ids(
-                    event.room_id, latest_event_ids=extrem_ids
-                )
+        extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids = set(extrem_ids)
+        prev_event_ids = set(event.prev_event_ids())
 
-            logger.debug(
-                "Doing soft-fail check for %s: state %s",
-                event.event_id,
-                current_state_ids,
+        if extrem_ids == prev_event_ids:
+            # If they're the same then the current state is the same as the
+            # state at the event, so no point rechecking auth for soft fail.
+            return
+
+        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+        # Calculate the "current state".
+        if state is not None:
+            # If we're explicitly given the state then we won't have all the
+            # prev events, and so we have a gap in the graph. In this case
+            # we want to be a little careful as we might have been down for
+            # a while and have an incorrect view of the current state,
+            # however we still want to do checks as gaps are easy to
+            # maliciously manufacture.
+            #
+            # So we use a "current state" that is actually a state
+            # resolution across the current forward extremities and the
+            # given state at the event. This should correctly handle cases
+            # like bans, especially with state res v2.
+
+            state_sets = await self.state_store.get_state_groups(
+                event.room_id, extrem_ids
+            )
+            state_sets = list(state_sets.values())
+            state_sets.append(state)
+            current_state_ids = await self.state_handler.resolve_events(
+                room_version, state_sets, event
+            )
+            current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+        else:
+            current_state_ids = await self.state_handler.get_current_state_ids(
+                event.room_id, latest_event_ids=extrem_ids
             )
 
-            # Now check if event pass auth against said current state
-            auth_types = auth_types_for_event(event)
-            current_state_ids = [
-                e for k, e in current_state_ids.items() if k in auth_types
-            ]
+        logger.debug(
+            "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
+        )
 
-            current_auth_events = await self.store.get_events(current_state_ids)
-            current_auth_events = {
-                (e.type, e.state_key): e for e in current_auth_events.values()
-            }
+        # Now check if the event passes auth against said current state
+        auth_types = auth_types_for_event(event)
+        current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
 
-            try:
-                event_auth.check(
-                    room_version_obj, event, auth_events=current_auth_events
-                )
-            except AuthError as e:
-                logger.warning("Soft-failing %r because %s", event, e)
-                event.internal_metadata.soft_failed = True
+        current_auth_events = await self.store.get_events(current_state_ids)
+        current_auth_events = {
+            (e.type, e.state_key): e for e in current_auth_events.values()
+        }
+
+        try:
+            event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+        except AuthError as e:
+            logger.warning("Soft-failing %r because %s", event, e)
+            event.internal_metadata.soft_failed = True
 
     async def on_query_auth(
         self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
                     remote_auth_chain = await self.federation_client.get_event_auth(
                         origin, event.room_id, event.event_id
                     )
-                except RequestSendFailed as e:
+                except RequestSendFailed as e1:
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
-                    logger.info("Failed to get event auth from remote: %s", e)
+                    logger.info("Failed to get event auth from remote: %s", e1)
                     return context
 
                 seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Checking auth on event %r", event.content)
 
-        last_exception = None
+        last_exception = None  # type: Optional[Exception]
+
         # for each public key in the 3pid invite event
         for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
             try:
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
                         return
             except Exception as e:
                 last_exception = e
+
+        if last_exception is None:
+            # we can only get here if get_public_keys() returned an empty list
+            # TODO: make this better
+            raise RuntimeError("no public key in invite event")
+
         raise last_exception
 
     async def _check_key_revocation(self, public_key, url):
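
Two hunks above (around send_join and send_leave) share one fix: target_hosts
may now be any iterable (see the widened `params` type), so it is copied into
a mutable list before the origin server is promoted to the front of the retry
order. The pattern in isolation, with placeholder hostnames:

    target_hosts = ("b.example", "origin.example", "c.example")
    origin = "origin.example"

    host_list = list(target_hosts)  # copy: the input may be a tuple
    try:
        # Try the host that answered /make_join/ (or /make_leave/) first.
        host_list.remove(origin)
        host_list.insert(0, origin)
    except ValueError:
        pass  # origin wasn't in the list; keep the original order

    assert host_list == ["origin.example", "b.example", "c.example"]
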
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6c7abaa578..879c4c07c6 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -294,6 +294,9 @@ class TypingHandler(object):
         rows.sort()
 
         limited = False
+        # We, unusually, use a strict limit here as we have all the rows in
+        # memory rather than pulling them out of the database with a `LIMIT ?`
+        # clause.
         if len(rows) > limit:
             rows = rows[:limit]
             current_id = rows[-1][0]
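
To unpack the new comment: because all the rows are already in memory, the
handler can truncate to exactly `limit` rows and use the stream id of the last
row actually returned as the resume point, instead of over-fetching with a SQL
`LIMIT ?`. On toy data:

    rows = sorted([(5, "e"), (2, "b"), (9, "x")])  # (stream_id, payload)
    limit = 2
    limited = False
    if len(rows) > limit:
        rows = rows[:limit]
        current_id = rows[-1][0]  # callers page again from this id
        limited = True
    assert rows == [(2, "b"), (5, "e")] and current_id == 5 and limited
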
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 096619a8c2..479746c9c5 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -13,13 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource
 
-from synapse.http.server import wrap_json_request_handler
 
-
-class AdditionalResource(Resource):
+class AdditionalResource(DirectServeJsonResource):
     """Resource wrapper for additional_resources
 
     If the user has configured additional_resources, we need to wrap the
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
             handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
                 function to be called to handle the request.
         """
-        Resource.__init__(self)
+        super().__init__()
         self._handler = handler
 
-        # required by the request_handler wrapper
-        self.clock = hs.get_clock()
-
-    def render(self, request):
-        self._async_render(request)
-        return NOT_DONE_YET
-
-    @wrap_json_request_handler
     def _async_render(self, request):
+        # Cheekily pass the result straight through, so we don't need to worry
+        # if it's an awaitable or not.
         return self._handler(request)
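
With this rewrite AdditionalResource inherits the response and error plumbing
from DirectServeJsonResource, and _async_render may hand back a Deferred, a
coroutine, or a plain value; the base class awaits only when needed. A hedged
sketch of the kind of handler this wraps (the handler itself is hypothetical):

    async def hello_handler(request):
        # Returning a (code, json_object) tuple lets the base class send the
        # response; returning None means the handler already responded.
        return 200, {"hello": "world"}

    # resource = AdditionalResource(hs, hello_handler)
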
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d192de7923..2b35f86066 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import collections
 import html
 import logging
@@ -21,7 +22,7 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Awaitable, Callable, TypeVar, Union
+from typing import Any, Callable, Dict, Tuple, Union
 
 import jinja2
 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 """
 
 
-def wrap_json_request_handler(h):
-    """Wraps a request handler method with exception handling.
-
-    Also does the wrapping with request.processing as per wrap_async_request_handler.
-
-    The handler method must have a signature of "handle_foo(self, request)",
-    where "request" must be a SynapseRequest.
-
-    The handler must return a deferred or a coroutine. If the deferred succeeds
-    we assume that a response has been sent. If the deferred fails with a SynapseError we use
-    it to send a JSON response with the appropriate HTTP reponse code. If the
-    deferred fails with any other type of error we send a 500 reponse.
+def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+    """Sends a JSON error response to clients.
     """
 
-    async def wrapped_request_handler(self, request):
-        try:
-            await h(self, request)
-        except SynapseError as e:
-            code = e.code
-            logger.info("%s SynapseError: %s - %s", request, code, e.msg)
-
-            # Only respond with an error response if we haven't already started
-            # writing, otherwise lets just kill the connection
-            if request.startedWriting:
-                if request.transport:
-                    try:
-                        request.transport.abortConnection()
-                    except Exception:
-                        # abortConnection throws if the connection is already closed
-                        pass
-            else:
-                respond_with_json(
-                    request,
-                    code,
-                    e.error_dict(),
-                    send_cors=True,
-                    pretty_print=_request_user_agent_is_curl(request),
-                )
-
-        except Exception:
-            # failure.Failure() fishes the original Failure out
-            # of our stack, and thus gives us a sensible stack
-            # trace.
-            f = failure.Failure()
-            logger.error(
-                "Failed handle request via %r: %r",
-                request.request_metrics.name,
-                request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
-            )
-            # Only respond with an error response if we haven't already started
-            # writing, otherwise lets just kill the connection
-            if request.startedWriting:
-                if request.transport:
-                    try:
-                        request.transport.abortConnection()
-                    except Exception:
-                        # abortConnection throws if the connection is already closed
-                        pass
-            else:
-                respond_with_json(
-                    request,
-                    500,
-                    {"error": "Internal server error", "errcode": Codes.UNKNOWN},
-                    send_cors=True,
-                    pretty_print=_request_user_agent_is_curl(request),
-                )
-
-    return wrap_async_request_handler(wrapped_request_handler)
-
-
-TV = TypeVar("TV")
-
-
-def wrap_html_request_handler(
-    h: Callable[[TV, SynapseRequest], Awaitable]
-) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
-    """Wraps a request handler method with exception handling.
+    if f.check(SynapseError):
+        error_code = f.value.code
+        error_dict = f.value.error_dict()
 
-    Also does the wrapping with request.processing as per wrap_async_request_handler.
-
-    The handler method must have a signature of "handle_foo(self, request)",
-    where "request" must be a SynapseRequest.
-    """
+        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+    else:
+        error_code = 500
+        error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
 
-    async def wrapped_request_handler(self, request):
-        try:
-            await h(self, request)
-        except Exception:
-            f = failure.Failure()
-            return_html_error(f, request, HTML_ERROR_TEMPLATE)
+        logger.error(
+            "Failed handle request via %r: %r",
+            request.request_metrics.name,
+            request,
+            exc_info=(f.type, f.value, f.getTracebackObject()),
+        )
 
-    return wrap_async_request_handler(wrapped_request_handler)
+    # Only respond with an error response if we haven't already started writing,
+    # otherwise lets just kill the connection
+    if request.startedWriting:
+        if request.transport:
+            try:
+                request.transport.abortConnection()
+            except Exception:
+                # abortConnection throws if the connection is already closed
+                pass
+    else:
+        respond_with_json(
+            request,
+            error_code,
+            error_dict,
+            send_cors=True,
+            pretty_print=_request_user_agent_is_curl(request),
+        )
 
 
 def return_html_error(
@@ -249,7 +194,113 @@ class HttpServer(object):
         pass
 
 
-class JsonResource(HttpServer, resource.Resource):
+class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
+    """Base class for resources that have async handlers.
+
+    Subclasses can either implement `_async_render_<METHOD>` to handle
+    requests by method, or override `_async_render` to handle all requests.
+
+    Args:
+        extract_context: Whether to attempt to extract the opentracing
+            context from the request the servlet is handling.
+    """
+
+    def __init__(self, extract_context=False):
+        super().__init__()
+
+        self._extract_context = extract_context
+
+    def render(self, request):
+        """ This gets called by twisted every time someone sends us a request.
+        """
+        defer.ensureDeferred(self._async_render_wrapper(request))
+        return NOT_DONE_YET
+
+    @wrap_async_request_handler
+    async def _async_render_wrapper(self, request):
+        """This is a wrapper that delegates to `_async_render` and handles
+        exceptions, return values, metrics, etc.
+        """
+        try:
+            request.request_metrics.name = self.__class__.__name__
+
+            with trace_servlet(request, self._extract_context):
+                callback_return = await self._async_render(request)
+
+                if callback_return is not None:
+                    code, response = callback_return
+                    self._send_response(request, code, response)
+        except Exception:
+            # failure.Failure() fishes the original Failure out
+            # of our stack, and thus gives us a sensible stack
+            # trace.
+            f = failure.Failure()
+            self._send_error_response(f, request)
+
+    async def _async_render(self, request):
+        """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
+        no appropriate method exists. Can be overridden in subclasses for
+        different routing.
+        """
+
+        method_handler = getattr(
+            self, "_async_render_%s" % (request.method.decode("ascii"),), None
+        )
+        if method_handler:
+            raw_callback_return = method_handler(request)
+
+            # Is it synchronous? We'll allow this for now.
+            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+                callback_return = await raw_callback_return
+            else:
+                callback_return = raw_callback_return
+
+            return callback_return
+
+        _unrecognised_request_handler(request)
+
+    @abc.abstractmethod
+    def _send_response(
+        self, request: SynapseRequest, code: int, response_object: Any,
+    ) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        raise NotImplementedError()
+
+
+class DirectServeJsonResource(_AsyncResource):
+    """A resource that will call `self._async_on_<METHOD>` on new requests,
+    formatting responses and errors as JSON.
+    """
+
+    def _send_response(
+        self, request, code, response_object,
+    ):
+        """Implements _AsyncResource._send_response
+        """
+        # TODO: Only enable CORS for the requests that need it.
+        respond_with_json(
+            request,
+            code,
+            response_object,
+            send_cors=True,
+            pretty_print=_request_user_agent_is_curl(request),
+            canonical_json=self.canonical_json,
+        )
+
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response
+        """
+        return_json_error(f, request)
+
+
+class JsonResource(DirectServeJsonResource):
     """ This implements the HttpServer interface and provides JSON support for
     Resources.
 
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
         "_PathEntry", ["pattern", "callback", "servlet_classname"]
     )
 
-    def __init__(self, hs, canonical_json=True):
-        resource.Resource.__init__(self)
+    def __init__(self, hs, canonical_json=True, extract_context=False):
+        super().__init__(extract_context)
 
         self.canonical_json = canonical_json
         self.clock = hs.get_clock()
         self.path_regexs = {}
         self.hs = hs
 
-    def register_paths(
-        self, method, path_patterns, callback, servlet_classname, trace=True
-    ):
+    def register_paths(self, method, path_patterns, callback, servlet_classname):
         """
         Registers a request handler against a regular expression. Later request URLs are
         checked against these regular expressions in order to identify an appropriate
@@ -295,37 +344,42 @@ class JsonResource(HttpServer, resource.Resource):
 
             servlet_classname (str): The name of the handler to be used in prometheus
                 and opentracing logs.
-
-            trace (bool): Whether we should start a span to trace the servlet.
         """
         method = method.encode("utf-8")  # method is bytes on py3
 
-        if trace:
-            # We don't extract the context from the servlet because we can't
-            # trust the sender
-            callback = trace_servlet(servlet_classname)(callback)
-
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
                 self._PathEntry(path_pattern, callback, servlet_classname)
             )
 
-    def render(self, request):
-        """ This gets called by twisted every time someone sends us a request.
+    def _get_handler_for_request(
+        self, request: SynapseRequest
+    ) -> Tuple[Callable, str, Dict[str, str]]:
+        """Finds a callback method to handle the given request.
+
+        Returns:
+            A tuple of the callback to use, the name of the servlet, and the
+            keyword arguments to pass to the callback
         """
-        defer.ensureDeferred(self._async_render(request))
-        return NOT_DONE_YET
+        request_path = request.path.decode("ascii")
+
+        # Loop through all the registered callbacks to check if the method
+        # and path regex match
+        for path_entry in self.path_regexs.get(request.method, []):
+            m = path_entry.pattern.match(request_path)
+            if m:
+                # We found a match!
+                return path_entry.callback, path_entry.servlet_classname, m.groupdict()
+
+        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+        return _unrecognised_request_handler, "unrecognised_request_handler", {}
 
-    @wrap_json_request_handler
     async def _async_render(self, request):
-        """ This gets called from render() every time someone sends us a request.
-            This checks if anyone has registered a callback for that method and
-            path.
-        """
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
-        # Make sure we have a name for this handler in prometheus.
+        # Make sure we have an appropriate name for this handler in prometheus
+        # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname
 
         # Now trigger the callback. If it returns a response, we send it
@@ -338,81 +392,42 @@ class JsonResource(HttpServer, resource.Resource):
             }
         )
 
-        callback_return = callback(request, **kwargs)
+        raw_callback_return = callback(request, **kwargs)
 
         # Is it synchronous? We'll allow this for now.
-        if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
-            callback_return = await callback_return
+        if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            callback_return = await raw_callback_return
+        else:
+            callback_return = raw_callback_return
 
-        if callback_return is not None:
-            code, response = callback_return
-            self._send_response(request, code, response)
+        return callback_return
 
-    def _get_handler_for_request(self, request):
-        """Finds a callback method to handle the given request
 
-        Args:
-            request (twisted.web.http.Request):
+class DirectServeHtmlResource(_AsyncResource):
+    """A resource that will call `self._async_on_<METHOD>` on new requests,
+    formatting responses and errors as HTML.
+    """
 
-        Returns:
-            Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
-                label to use for that method in prometheus metrics, and the
-                dict mapping keys to path components as specified in the
-                handler's path match regexp.
-
-                The callback will normally be a method registered via
-                register_paths, so will return (possibly via Deferred) either
-                None, or a tuple of (http code, response body).
-        """
-        request_path = request.path.decode("ascii")
-
-        # Loop through all the registered callbacks to check if the method
-        # and path regex match
-        for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request_path)
-            if m:
-                # We found a match!
-                return path_entry.callback, path_entry.servlet_classname, m.groupdict()
-
-        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        return _unrecognised_request_handler, "unrecognised_request_handler", {}
+    # The error template to use for this resource
+    ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
 
     def _send_response(
-        self, request, code, response_json_object, response_code_message=None
+        self, request: SynapseRequest, code: int, response_object: Any,
     ):
-        # TODO: Only enable CORS for the requests that need it.
-        respond_with_json(
-            request,
-            code,
-            response_json_object,
-            send_cors=True,
-            response_code_message=response_code_message,
-            pretty_print=_request_user_agent_is_curl(request),
-            canonical_json=self.canonical_json,
-        )
-
-
-class DirectServeResource(resource.Resource):
-    def render(self, request):
+        """Implements _AsyncResource._send_response
         """
-        Render the request, using an asynchronous render handler if it exists.
-        """
-        async_render_callback_name = "_async_render_" + request.method.decode("ascii")
-
-        # Try and get the async renderer
-        callback = getattr(self, async_render_callback_name, None)
+        # We expect to get bytes for us to write
+        assert isinstance(response_object, bytes)
+        html_bytes = response_object
 
-        # No async renderer for this request method.
-        if not callback:
-            return super().render(request)
+        respond_with_html_bytes(request, code, html_bytes)
 
-        resp = trace_servlet(self.__class__.__name__)(callback)(request)
-
-        # If it's a coroutine, turn it into a Deferred
-        if isinstance(resp, types.CoroutineType):
-            defer.ensureDeferred(resp)
-
-        return NOT_DONE_YET
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response
+        """
+        return_html_error(f, request, self.ERROR_TEMPLATE)
 
 
 class StaticResource(File):
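
The net effect of the refactor is that per-method servlets now subclass
DirectServeJsonResource (or DirectServeHtmlResource) and implement
_async_render_<METHOD>, rather than decorating render methods by hand. A
hedged sketch of such a subclass (the resource is hypothetical; note that this
version of _send_response reads self.canonical_json, so the sketch provides
it):

    class PingResource(DirectServeJsonResource):
        canonical_json = False  # read by _send_response in this version

        async def _async_render_GET(self, request):
            # A (code, object) return is serialised as JSON with CORS headers.
            return 200, {"pong": True}
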
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 73bef5e5ca..c6c0e623c1 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -164,12 +164,10 @@ Gotchas
  than one caller? Will all of those calling functions have to be in a context
   with an active span?
 """
-
 import contextlib
 import inspect
 import logging
 import re
-import types
 from functools import wraps
 from typing import TYPE_CHECKING, Dict, Optional, Type
 
@@ -181,6 +179,7 @@ from twisted.internet import defer
 from synapse.config import ConfigError
 
 if TYPE_CHECKING:
+    from synapse.http.site import SynapseRequest
     from synapse.server import HomeServer
 
 # Helper class
@@ -227,6 +226,7 @@ except ImportError:
     tags = _DummyTagNames
 try:
     from jaeger_client import Config as JaegerConfig
+
     from synapse.logging.scopecontextmanager import LogContextScopeManager
 except ImportError:
     JaegerConfig = None  # type: ignore
@@ -793,48 +793,42 @@ def tag_args(func):
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, extract_context=False):
-    """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+    """Returns a context manager which traces a request. It starts a span
+    with some servlet specific tags such as the request metrics name and
+    request information.
 
     Args:
-        servlet_name (str): The name to be used for the span's operation_name
-        extract_context (bool): Whether to attempt to extract the opentracing
+        request
+        extract_context: Whether to attempt to extract the opentracing
             context from the request the servlet is handling.
-
     """
 
-    def _trace_servlet_inner_1(func):
-        if not opentracing:
-            return func
-
-        @wraps(func)
-        async def _trace_servlet_inner(request, *args, **kwargs):
-            request_tags = {
-                "request_id": request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-            }
-
-            if extract_context:
-                scope = start_active_span_from_request(
-                    request, servlet_name, tags=request_tags
-                )
-            else:
-                scope = start_active_span(servlet_name, tags=request_tags)
-
-            with scope:
-                result = func(request, *args, **kwargs)
+    if opentracing is None:
+        yield
+        return
 
-                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
-                    # Some servlets aren't async and just return results
-                    # directly, so we handle that here.
-                    return result
+    request_tags = {
+        "request_id": request.get_request_id(),
+        tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+        tags.HTTP_METHOD: request.get_method(),
+        tags.HTTP_URL: request.get_redacted_uri(),
+        tags.PEER_HOST_IPV6: request.getClientIP(),
+    }
 
-                return await result
+    request_name = request.request_metrics.name
+    if extract_context:
+        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+    else:
+        scope = start_active_span(request_name, tags=request_tags)
 
-        return _trace_servlet_inner
+    with scope:
+        try:
+            yield
+        finally:
+            # We set the operation name again in case it's changed (which happens
+            # with JsonResource).
+            scope.span.set_operation_name(request.request_metrics.name)
 
-    return _trace_servlet_inner_1
+            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index ed60dbc1bf..2fac07593b 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -20,6 +20,7 @@ from prometheus_client import Counter
 from twisted.internet import defer
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
+from synapse.api.constants import EventTypes
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.push import PusherConfigException
@@ -305,12 +306,23 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _build_notification_dict(self, event, tweaks, badge):
+        priority = "low"
+        if (
+            event.type == EventTypes.Encrypted
+            or tweaks.get("highlight")
+            or tweaks.get("sound")
+        ):
+            # HACK: send our push as high priority only if it generates a sound or
+            # highlight, or may do so (i.e. is encrypted, so has unknown effects).
+            priority = "high"
+
         if self.data.get("format") == "event_id_only":
             d = {
                 "notification": {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
                     "counts": {"unread": badge},
+                    "prio": priority,
                     "devices": [
                         {
                             "app_id": self.app_id,
@@ -334,9 +346,8 @@ class HttpPusher(object):
                 "room_id": event.room_id,
                 "type": event.type,
                 "sender": event.user_id,
-                "counts": {  # -- we don't mark messages as read yet so
-                    # we have no way of knowing
-                    # Just set the badge to 1 until we have read receipts
+                "prio": priority,
+                "counts": {
                     "unread": badge,
                     # 'missed_calls': 2
                 },
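
The priority heuristic above is small enough to check in isolation; a self-contained sketch, using the literal "m.room.encrypted" in place of EventTypes.Encrypted:

    # Standalone sketch of the heuristic introduced above: encrypted events
    # may do anything, so they are treated like sound/highlight tweaks.
    def notification_priority(event_type: str, tweaks: dict) -> str:
        if (
            event_type == "m.room.encrypted"
            or tweaks.get("highlight")
            or tweaks.get("sound")
        ):
            return "high"
        return "low"

    assert notification_priority("m.room.message", {}) == "low"
    assert notification_priority("m.room.message", {"sound": "default"}) == "high"
    assert notification_priority("m.room.encrypted", {}) == "high"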
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 8e0d3a416d..2d79ada189 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,7 +16,7 @@
 
 import logging
 import re
-from typing import Pattern
+from typing import Any, Dict, List, Pattern, Union
 
 from synapse.events import EventBase
 from synapse.types import UserID
@@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number):
         return False
 
 
-def tweaks_for_actions(actions):
+def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
+    """
+    Converts a list of actions into a `tweaks` dict (which can then be passed to
+        the push gateway).
+
+    This function ignores all actions other than `set_tweak` actions, and treats
+    absent `value`s as `True`, which agrees with the only spec-defined treatment
+    of absent `value`s (namely, for `highlight` tweaks).
+
+    Args:
+        actions: list of actions
+            e.g. [
+                {"set_tweak": "a", "value": "AAA"},
+                {"set_tweak": "b", "value": "BBB"},
+                {"set_tweak": "highlight"},
+                "notify"
+            ]
+
+    Returns:
+        dictionary of tweaks for those actions
+            e.g. {"a": "AAA", "b": "BBB", "highlight": True}
+    """
     tweaks = {}
     for a in actions:
         if not isinstance(a, dict):
             continue
-        if "set_tweak" in a and "value" in a:
-            tweaks[a["set_tweak"]] = a["value"]
+        if "set_tweak" in a:
+            # `value` is allowed to be absent, in which case it is assumed
+            # to be True.
+            tweaks[a["set_tweak"]] = a.get("value", True)
     return tweaks
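
The docstring's example can be checked directly; for instance:

    # Usage matching the docstring above: non-dict actions are skipped, and a
    # `set_tweak` with no `value` defaults to True.
    actions = [
        {"set_tweak": "sound", "value": "default"},
        {"set_tweak": "highlight"},  # no value -> True
        "notify",                    # not a dict -> ignored
    ]
    assert tweaks_for_actions(actions) == {"sound": "default", "highlight": True}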
 
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index b1cac901eb..8cfcdb0573 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -66,7 +66,7 @@ REQUIREMENTS = [
     "pymacaroons>=0.13.0",
     "msgpack>=0.5.2",
     "phonenumbers>=8.2.0",
-    "prometheus_client>=0.0.18,<0.8.0",
+    "prometheus_client>=0.0.18,<0.9.0",
     # we use attr.validators.deep_iterable, which arrived in 19.1.0
     "attrs>=19.1.0",
     "netaddr>=0.7.18",
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 19b69e0e11..5ef1c6c1dc 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
 
 class ReplicationRestResource(JsonResource):
     def __init__(self, hs):
-        JsonResource.__init__(self, hs, canonical_json=False)
+        # We enable extracting jaeger contexts here as these are internal APIs.
+        super().__init__(hs, canonical_json=False, extract_context=True)
         self.register_servlets(hs)
 
     def register_servlets(self, hs):
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 9caf1e80c1..0843d28d4b 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -28,11 +28,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.logging.opentracing import (
-    inject_active_span_byte_dict,
-    trace,
-    trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
 
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
-        handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
-        # We don't let register paths trace this servlet using the default tracing
-        # options because we wish to extract the context explicitly.
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__, trace=False
+            method, [pattern], handler, self.__class__.__name__,
         )
 
     def _cached_handler(self, request, txn_id, **kwargs):
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..525b94fd87 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,6 +16,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.storage.data_stores.main.tags import TagsWorkerStore
 from synapse.storage.database import Database
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
         return self._account_data_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "tag_account_data":
+        if stream_name == TagAccountDataStream.NAME:
             self._account_data_id_gen.advance(token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == "account_data":
+        elif stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(token)
             for row in rows:
                 if not row.room_id:
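
The same substitution recurs in every slaved store below: each stream class exposes its wire name as a class attribute, so a mistyped name fails loudly at import time instead of silently never matching. A minimal sketch of the pattern (the real classes live in synapse/replication/tcp/streams):

    # Sketch: comparing against a NAME constant rather than a string literal.
    class TagAccountDataStream:
        NAME = "tag_account_data"

    def process_replication_rows(stream_name: str) -> bool:
        # A typo such as TagAccountDataStream.NAAME raises AttributeError
        # immediately, whereas "tag_account_dta" would just never match.
        return stream_name == TagAccountDataStream.NAME

    assert process_replication_rows("tag_account_data")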
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..bd394f6b00 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,6 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "to_device":
+        if stream_name == ToDeviceStream.NAME:
             self._device_inbox_id_gen.advance(token)
             for row in rows:
                 if row.entity.startswith("@"):
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..5d210fa3a1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,6 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import GroupServerStream
 from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
 from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
         return self._group_updates_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "groups":
+        if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..2938cb8e43 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage import DataStore
 from synapse.storage.data_stores.main.presence import PresenceStore
 from synapse.storage.database import Database
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
         return self._presence_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "presence":
+        if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..23ec1c5b11 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
 
 from .events import SlavedEventStore
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "push_rules":
+        if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..ff449f3658 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.database import Database
 
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "pushers":
+        if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..6982686eb5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,20 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
     def __init__(self, database: Database, db_conn, hs):
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "receipts":
+        if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(token)
             for row in rows:
                 self.invalidate_caches_for_receipt(
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..8710207ada 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PublicRoomsStream
 from synapse.storage.data_stores.main.room import RoomWorkerStore
 from synapse.storage.database import Database
 
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
         return self._public_room_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "public_rooms":
+        if stream_name == PublicRoomsStream.NAME:
             self._public_room_id_gen.advance(token)
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..4985e40b1f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
-    from synapse.server import HomeServer
     from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index e6a2e2598b..55b3b79008 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
 
@@ -149,10 +148,11 @@ class ReplicationCommandHandler:
         using TCP.
         """
         if hs.config.redis.redis_enabled:
+            import txredisapi
+
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
             )
-            import txredisapi
 
             logger.info(
                 "Connecting to redis (host=%r port=%r)",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f196eff072..9076bbe9f1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -198,26 +198,6 @@ def current_token_without_instance(
     return lambda instance_name: current_token()
 
 
-def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> UpdateFunction:
-    """Wraps a db query function which returns a list of rows to make it
-    suitable for use as an `update_function` for the Stream class
-    """
-
-    async def update_function(instance_name, from_token, upto_token, limit):
-        rows = await query_function(from_token, upto_token, limit)
-        updates = [(row[0], row[1:]) for row in rows]
-        limited = False
-        if len(updates) >= limit:
-            upto_token = updates[-1][0]
-            limited = True
-
-        return updates, upto_token, limited
-
-    return update_function
-
-
 def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
@@ -393,7 +373,7 @@ class PushersStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_pushers_stream_token),
-            db_query_to_update_function(store.get_all_updated_pushers_rows),
+            store.get_all_updated_pushers_rows,
         )
 
 
@@ -421,26 +401,12 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        self.store = hs.get_datastore()
+        store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_cache_stream_token,
-            self._update_function,
-        )
-
-    async def _update_function(
-        self, instance_name: str, from_token: int, upto_token: int, limit: int
-    ):
-        rows = await self.store.get_all_updated_caches(
-            instance_name, from_token, upto_token, limit
+            store.get_cache_stream_token,
+            store.get_all_updated_caches,
         )
-        updates = [(row[0], row[1:]) for row in rows]
-        limited = False
-        if len(updates) >= limit:
-            upto_token = updates[-1][0]
-            limited = True
-
-        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_current_public_room_stream_id),
-            db_query_to_update_function(store.get_all_new_public_rooms),
+            store.get_all_new_public_rooms,
         )
 
 
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
-            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+            store.get_all_device_list_changes_for_remotes,
         )
 
 
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_to_device_stream_token),
-            db_query_to_update_function(store.get_all_new_device_messages),
+            store.get_all_new_device_messages,
         )
 
 
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_max_account_data_stream_id),
-            db_query_to_update_function(store.get_all_updated_tags),
+            store.get_all_updated_tags,
         )
 
 
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_group_stream_token),
-            db_query_to_update_function(store.get_all_groups_changes),
+            store.get_all_groups_changes,
         )
 
 
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(store.get_device_stream_token),
-            db_query_to_update_function(
-                store.get_all_user_signature_changes_for_remotes
-            ),
+            store.get_all_user_signature_changes_for_remotes,
         )
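
With db_query_to_update_function removed, each store method passed as an update_function must itself satisfy the Stream contract: accept (instance_name, from_token, upto_token, limit) and return (updates, upto_token, limited). A minimal in-memory sketch of a conforming function, with ROWS standing in for a database table:

    from typing import List, Tuple

    # (stream_id, row-data) pairs standing in for a DB query's results.
    ROWS = [(i, ("row-%d" % i,)) for i in range(1, 21)]

    async def update_function(
        instance_name: str, from_token: int, upto_token: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        updates = [r for r in ROWS if from_token < r[0] <= upto_token][:limit]
        limited = False
        if len(updates) >= limit:
            # We may have stopped short of upto_token; report how far we got
            # so the caller can ask for the rest.
            upto_token = updates[-1][0]
            limited = True
        return updates, upto_token, limited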
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..bdddb62ad6 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import heapq
 from collections import Iterable
 from typing import List, Tuple, Type
@@ -22,7 +21,6 @@ import attr
 
 from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
-
 """Handling of the 'events' replication stream
 
 This stream contains rows of various types. Each row therefore contains a 'type'
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index bf0f9bd077..64d5c58b65 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
 from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import canonicalise_email
 
 logger = logging.getLogger(__name__)
 
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest):
         flows = []
         if self.jwt_enabled:
             flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
 
         return 200, {"flows": flows}
 
-    def on_OPTIONS(self, request):
+    def on_OPTIONS(self, request: SynapseRequest):
         return 200, {}
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest):
         self._address_ratelimiter.ratelimit(request.getClientIP())
 
         login_submission = parse_json_object_from_request(request)
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
-                result = await self.do_jwt_login(login_submission)
+                result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                result = await self.do_token_login(login_submission)
+                result = await self._do_token_login(login_submission)
             else:
                 result = await self._do_other_login(login_submission)
         except KeyError:
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
             result["well_known"] = well_known_data
         return 200, result
 
-    async def _do_other_login(self, login_submission):
+    async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
 
         Args:
             login_submission:
 
         Returns:
-            dict: HTTP response
+            HTTP response
         """
         # Log the request we got, but only certain fields to minimise the chance of
         # logging someone's password (even if they accidentally put it in the wrong
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
             if medium is None or address is None:
                 raise SynapseError(400, "Invalid thirdparty identifier")
 
+            # For emails, canonicalise the address.
+            # We store all email addresses canonicalised in the DB.
+            # (See add_threepid in synapse/handlers/auth.py)
             if medium == "email":
-                # For emails, transform the address to lowercase.
-                # We store all email addreses as lowercase in the DB.
-                # (See add_threepid in synapse/handlers/auth.py)
-                address = address.lower()
+                try:
+                    address = canonicalise_email(address)
+                except ValueError as e:
+                    raise SynapseError(400, str(e))
 
             # We also apply account rate limiting using the 3PID as a key, as
             # otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
         return result
 
     async def _complete_login(
-        self, user_id, login_submission, callback=None, create_non_existent_users=False
-    ):
+        self,
+        user_id: str,
+        login_submission: JsonDict,
+        callback: Optional[
+            Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+        ] = None,
+        create_non_existent_users: bool = False,
+    ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
-        all succesful logins.
+        all successful logins.
 
-        Applies the ratelimiting for succesful login attempts against an
+        Applies the ratelimiting for successful login attempts against an
         account.
 
         Args:
-            user_id (str): ID of the user to register.
-            login_submission (dict): Dictionary of login information.
-            callback (func|None): Callback function to run after registration.
-            create_non_existent_users (bool): Whether to create the user if
-                they don't exist. Defaults to False.
+            user_id: ID of the user to register.
+            login_submission: Dictionary of login information.
+            callback: Callback function to run after registration.
+            create_non_existent_users: Whether to create the user if they don't
+                exist. Defaults to False.
 
         Returns:
-            result (Dict[str,str]): Dictionary of account information after
-                successful registration.
+            result: Dictionary of account information after successful registration.
         """
 
         # Before we actually log them in we check if they've already logged in
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
 
         return result
 
-    async def do_token_login(self, login_submission):
+    async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
         token = login_submission["token"]
         auth_handler = self.auth_handler
         user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
         result = await self._complete_login(user_id, login_submission)
         return result
 
-    async def do_jwt_login(self, login_submission):
+    async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
         token = login_submission.get("token", None)
         if token is None:
             raise LoginError(
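
canonicalise_email itself lives in synapse/util/threepids.py (also touched by this commit). As a hedged sketch of the behaviour the call sites rely on, assuming canonicalisation means trimming whitespace and lower-casing an address with exactly one "@":

    # Illustrative only: the real implementation is in synapse/util/threepids.py.
    # What matters to the servlets above is the ValueError contract, which they
    # convert into a 400 SynapseError.
    def canonicalise_email_sketch(address: str) -> str:
        address = address.strip()
        parts = address.split("@")
        if len(parts) != 2:
            raise ValueError("Unable to parse email address")
        return "@".join(part.lower() for part in parts)

    assert canonicalise_email_sketch(" FOO@Bar.com ") == "foo@bar.com"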
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 182a308eef..3767a809a4 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -30,7 +30,7 @@ from synapse.http.servlet import (
 from synapse.push.mailer import Mailer, load_jinja2_templates
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
@@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
-        email = body["email"]
+        # Canonicalise the email address. The addresses are all stored canonicalised
+        # in the database. This allows users to reset their password without having
+        # to know the exact spelling (e.g. upper vs. lower case) of the address in
+        # the database: if "foo@bar.com" is stored, a request for "FOO@bar.com"
+        # would otherwise raise a Not Found error.
+        try:
+            email = canonicalise_email(body["email"])
+        except ValueError as e:
+            raise SynapseError(400, str(e))
         send_attempt = body["send_attempt"]
         next_link = body.get("next_link")  # Optional param
 
@@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        # The email will be sent to the stored address.
+        # This avoids a potential account hijack by requesting a password reset to
+        # an email address which is controlled by the attacker but which, after
+        # canonicalisation, matches the one in our database.
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet):
                 if "medium" not in threepid or "address" not in threepid:
                     raise SynapseError(500, "Malformed threepid")
                 if threepid["medium"] == "email":
-                    # For emails, transform the address to lowercase.
-                    # We store all email addreses as lowercase in the DB.
+                    # For emails, canonicalise the address.
+                    # We store all email addresses canonicalised in the DB.
                     # (See add_threepid in synapse/handlers/auth.py)
-                    threepid["address"] = threepid["address"].lower()
+                    try:
+                        threepid["address"] = canonicalise_email(threepid["address"])
+                    except ValueError as e:
+                        raise SynapseError(400, str(e))
                 # if using email, we must know about the email they're authing with!
                 threepid_user_id = await self.datastore.get_user_id_by_threepid(
                     threepid["medium"], threepid["address"]
@@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
-        email = body["email"]
+        # Canonicalise the email address. The addresses are all stored canonicalised
+        # in the database.
+        # This ensures that the validation email is sent to the canonicalised address
+        # as it will later be entered into the database.
+        # Otherwise the email will be sent to "FOO@bar.com" and stored as
+        # "foo@bar.com" in database.
+        try:
+            email = canonicalise_email(body["email"])
+        except ValueError as e:
+            raise SynapseError(400, str(e))
         send_attempt = body["send_attempt"]
         next_link = body.get("next_link")  # Optional param
 
@@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        existing_user_id = await self.store.get_user_id_by_threepid(
-            "email", body["email"]
-        )
+        existing_user_id = await self.store.get_user_id_by_threepid("email", email)
 
         if existing_user_id is not None:
             if self.config.request_token_inhibit_3pid_errors:
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 56a451c42f..370742ce59 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
@@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
-        email = body["email"]
+        # For emails, canonicalise the address.
+        # We store all email addresses canonicalised in the DB.
+        # (See on_POST in EmailThreepidRequestTokenRestServlet
+        # in synapse/rest/client/v2_alpha/account.py)
+        try:
+            email = canonicalise_email(body["email"])
+        except ValueError as e:
+            raise SynapseError(400, str(e))
         send_attempt = body["send_attempt"]
         next_link = body.get("next_link")  # Optional param
 
@@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             )
 
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
-            "email", body["email"]
+            "email", email
         )
 
         if existing_user_id is not None:
@@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet):
                     if login_type in auth_result:
                         medium = auth_result[login_type]["medium"]
                         address = auth_result[login_type]["address"]
+                        # For emails, canonicalise the address.
+                        # We store all email addresses canonicalised in the DB.
+                        # (See on_POST in EmailThreepidRequestTokenRestServlet
+                        # in synapse/rest/client/v2_alpha/account.py)
+                        if medium == "email":
+                            try:
+                                address = canonicalise_email(address)
+                            except ValueError as e:
+                                raise SynapseError(400, str(e))
 
                         existing_user_id = await self.store.get_user_id_by_threepid(
                             medium, address
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 0a890c98cb..4386eb4e72 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -26,11 +26,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import NotFoundError, StoreError, SynapseError
 from synapse.config import ConfigError
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_html,
-    wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.types import UserID
 
@@ -48,7 +44,7 @@ else:
         return a == b
 
 
-class ConsentResource(DirectServeResource):
+class ConsentResource(DirectServeHtmlResource):
     """A twisted Resource to display a privacy policy and gather consent to it
 
     When accessed via GET, returns the privacy policy via a template.
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
 
         self._hmac_secret = hs.config.form_secret.encode("utf-8")
 
-    @wrap_html_request_handler
     async def _async_render_GET(self, request):
         """
         Args:
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
         except TemplateNotFound:
             raise NotFoundError("Unknown policy version")
 
-    @wrap_html_request_handler
     async def _async_render_POST(self, request):
         """
         Args:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index ab671f7334..e149ac1733 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json_bytes,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 
 logger = logging.getLogger(__name__)
 
 
-class RemoteKey(DirectServeResource):
+class RemoteKey(DirectServeJsonResource):
     """HTTP resource for retreiving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
     isLeaf = True
 
     def __init__(self, hs):
+        super().__init__()
+
         self.fetcher = ServerKeyFetcher(hs)
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
         self.config = hs.config
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         if len(request.postpath) == 1:
             (server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
 
         await self.query_keys(request, query, query_remote_on_cache_miss=True)
 
-    @wrap_json_request_handler
     async def _async_render_POST(self, request):
         content = parse_json_object_from_request(request)
 
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9f747de263..68dd2a1c8a 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -14,16 +14,10 @@
 # limitations under the License.
 #
 
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json,
-    wrap_json_request_handler,
-)
 
-
-class MediaConfigResource(DirectServeResource):
+class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         respond_with_json(request, 200, {}, send_cors=True)
-        return NOT_DONE_YET
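
The media and key resources below all follow the same migration: subclass DirectServeJsonResource (or DirectServeHtmlResource for HTML endpoints), drop the wrap_*_request_handler decorators, and implement plain async _async_render_* methods. A hedged skeleton of the resulting shape (MyResource is illustrative):

    from synapse.http.server import DirectServeJsonResource, respond_with_json

    class MyResource(DirectServeJsonResource):
        isLeaf = True

        def __init__(self, hs):
            super().__init__()
            self.auth = hs.get_auth()

        async def _async_render_GET(self, request):
            # Error handling and tracing now live in the base class; the
            # handler just authenticates and responds.
            await self.auth.get_user_by_req(request)
            respond_with_json(request, 200, {"ok": True}, send_cors=True)

        async def _async_render_OPTIONS(self, request):
            respond_with_json(request, 200, {}, send_cors=True)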
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 24d3ae5bbc..d3d8457303 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,18 +15,14 @@
 import logging
 
 import synapse.http.servlet
-from synapse.http.server import (
-    DirectServeResource,
-    set_cors_headers,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
 
 from ._base import parse_media_id, respond_404
 
 logger = logging.getLogger(__name__)
 
 
-class DownloadResource(DirectServeResource):
+class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-        # this is expected by @wrap_json_request_handler
-        self.clock = hs.get_clock()
-
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         set_cors_headers(request)
         request.setHeader(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b4645cd608..e52c86c798 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
-    DirectServeResource,
+    DirectServeJsonResource,
     respond_with_json,
     respond_with_json_bytes,
-    wrap_json_request_handler,
 )
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
 OG_TAG_VALUE_MAXLEN = 1000
 
 
-class PreviewUrlResource(DirectServeResource):
+class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo, media_storage):
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         request.setHeader(b"Allow", b"OPTIONS, GET")
-        return respond_with_json(request, 200, {}, send_cors=True)
+        respond_with_json(request, 200, {}, send_cors=True)
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 0b87220234..a83535b97b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,11 +16,7 @@
 
 import logging
 
-from synapse.http.server import (
-    DirectServeResource,
-    set_cors_headers,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
 
 from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
 logger = logging.getLogger(__name__)
 
 
-class ThumbnailResource(DirectServeResource):
+class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
         self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
-        self.clock = hs.get_clock()
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c234ea7421..7126997134 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -12,11 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from io import BytesIO
 
-import PIL.Image as Image
+from PIL import Image as Image
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 83d005812d..3ebf7a68e6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,14 @@
 
 import logging
 
-from twisted.web.server import NOT_DONE_YET
-
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
 logger = logging.getLogger(__name__)
 
 
-class UploadResource(DirectServeResource):
+class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         respond_with_json(request, 200, {}, send_cors=True)
-        return NOT_DONE_YET
 
-    @wrap_json_request_handler
     async def _async_render_POST(self, request):
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
index c03194f001..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/oidc/callback_resource.py
@@ -14,18 +14,17 @@
 # limitations under the License.
 import logging
 
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import DirectServeHtmlResource
 
 logger = logging.getLogger(__name__)
 
 
-class OIDCCallbackResource(DirectServeResource):
+class OIDCCallbackResource(DirectServeHtmlResource):
     isLeaf = 1
 
     def __init__(self, hs):
         super().__init__()
         self._oidc_handler = hs.get_oidc_handler()
 
-    @wrap_html_request_handler
     async def _async_render_GET(self, request):
-        return await self._oidc_handler.handle_oidc_callback(request)
+        await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 75e58043b4..c10188a5d7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -16,10 +16,10 @@
 from twisted.python import failure
 
 from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource, return_html_error
 
 
-class SAML2ResponseResource(DirectServeResource):
+class SAML2ResponseResource(DirectServeHtmlResource):
     """A Twisted web resource which handles the SAML response"""
 
     isLeaf = 1
diff --git a/synapse/secrets.py b/synapse/secrets.py
index 0b327a0f82..5f43f81eb0 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -19,7 +19,6 @@ Injectable secrets module for Synapse.
 See https://docs.python.org/3/library/secrets.html#module-secrets for the API
 used in Python 3.6, and the API emulated in Python 2.7.
 """
-
 import sys
 
 # secrets is available since python 3.6
@@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6):
 
 
 else:
-    import os
     import binascii
+    import os
 
     class Secrets(object):
         def token_bytes(self, nbytes=32):
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index eac5a4e55b..f39f556c20 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,10 +16,12 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams import BackfillStream, CachesStream
 from synapse.replication.tcp.streams.events import (
+    EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
 )
@@ -44,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ):
-        """Fetches cache invalidation rows between the two given IDs written
-        by the given instance. Returns at most `limit` rows.
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for caches replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
 
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, instance_name, limit))
-            return txn.fetchall()
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
 
         return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
+        if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
+        elif stream_name == BackfillStream.NAME:
             for row in rows:
                 self._invalidate_caches_for_event(
                     -token,
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     row.relates_to,
                     backfilled=True,
                 )
-        elif stream_name == "caches":
+        elif stream_name == CachesStream.NAME:
             if self._cache_id_gen:
                 self._cache_id_gen.advance(instance_name, token)
 
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 9a1178fb39..d313b9705f 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
+    async def get_all_new_device_messages(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for to device replication stream.
+
         Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
         Returns:
-            A deferred list of rows from the device inbox
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
-        if last_pos == current_pos:
-            return defer.succeed([])
+
+        if last_id == current_id:
+            return [], current_id, False
 
         def get_all_new_device_messages_txn(txn):
             # We limit like this as we might have multiple rows per stream_id, and
             # we want to make sure we always get all entries for any stream_id
             # we return.
-            upper_pos = min(current_pos, last_pos + limit)
+            upper_pos = min(current_id, last_id + limit)
             sql = (
                 "SELECT max(stream_id), user_id"
                 " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " GROUP BY user_id"
             )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
+            txn.execute(sql, (last_id, upper_pos))
+            updates = [(row[0], row[1:]) for row in txn]
 
             sql = (
                 "SELECT max(stream_id), destination"
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " GROUP BY destination"
             )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
+            txn.execute(sql, (last_id, upper_pos))
+            updates.extend((row[0], row[1:]) for row in txn)
 
             # Order by ascending stream ordering
-            rows.sort()
+            updates.sort()
 
-            return rows
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
 
-        return self.db.runInteraction(
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
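
A note on the bound used above: several device messages can share one stream_id, and a plain row-count LIMIT could split such a group across pages. Capping the stream_id range instead guarantees each returned stream_id is complete:

    # Sketch of the paging bound: every stream_id in (last_id, upper_pos] is
    # either wholly included or wholly excluded, however many rows share it.
    last_id, limit, current_id = 10, 5, 100
    upper_pos = min(current_id, last_id + limit)
    assert upper_pos == 15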
 
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0ff0542453..343cf9a2d5 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
             return set()
 
     async def get_all_device_list_changes_for_remotes(
-        self, from_key: int, to_key: int, limit: int,
-    ) -> List[Tuple[int, str]]:
-        """Return a list of `(stream_id, entity)` which is the combined list of
-        changes to devices and which destinations need to be poked. Entity is
-        either a user ID (starting with '@') or a remote destination.
-        """
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for device lists replication stream.
 
-        # This query Does The Right Thing where it'll correctly apply the
-        # bounds to the inner queries.
-        sql = """
-            SELECT stream_id, entity FROM (
-                SELECT stream_id, user_id AS entity FROM device_lists_stream
-                UNION ALL
-                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
-            ) AS e
-            WHERE ? < stream_id AND stream_id <= ?
-            LIMIT ?
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
 
-        return await self.db.execute(
+        if last_id == current_id:
+            return [], current_id, False
+
+        def _get_all_device_list_changes_for_remotes(txn):
+            # This query Does The Right Thing where it'll correctly apply the
+            # bounds to the inner queries.
+            sql = """
+                SELECT stream_id, entity FROM (
+                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                    UNION ALL
+                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                ) AS e
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_device_list_changes_for_remotes",
-            None,
-            sql,
-            from_key,
-            to_key,
-            limit,
+            _get_all_device_list_changes_for_remotes,
         )
 
     @cached(max_entries=10000)
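
The tail of each transaction above repeats one idiom: if at least `limit` rows came back, the resume token is the stream ID of the last row returned and the result is flagged as limited; otherwise the caller can safely advance straight to current_id. Factored out as a hypothetical helper (not part of this patch):

def clamp_updates(updates, current_id: int, limit: int):
    """Derive the (upto_token, limited) pair from a sorted updates list."""
    if len(updates) >= limit:
        # We may have stopped mid-stream, so only vouch for what we returned.
        return updates[-1][0], True
    # Fewer rows than the limit: we know we covered everything up to current_id.
    return current_id, False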
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 1a0842d4b0..6c3cff82e1 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 from canonicaljson import encode_canonical_json, json
 
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return result
 
-    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
-        """Return a list of changes from the user signature stream to notify remotes.
+    async def get_all_user_signature_changes_for_remotes(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for groups replication stream.
+
         Note that the user signature stream represents when a user signs their
         device with their user-signing key, which is not published to other
         users or servers, so no `destination` is needed in the returned
         list. However, this is needed to poke workers.
 
         Args:
-            from_key (int): the stream ID to start at (exclusive)
-            to_key (int): the stream ID to end at (inclusive)
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
 
         Returns:
-            Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
-        """
-        sql = """
-            SELECT stream_id, from_user_id AS user_id
-            FROM user_signature_stream
-            WHERE ? < stream_id AND stream_id <= ?
-            ORDER BY stream_id ASC
-            LIMIT ?
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
-        return self.db.execute(
+
+        if last_id == current_id:
+            return [], current_id, False
+
+        def _get_all_user_signature_changes_for_remotes_txn(txn):
+            sql = """
+                SELECT stream_id, from_user_id AS user_id
+                FROM user_signature_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
+
+            updates = [(row[0], row[1:]) for row in txn]
+
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_user_signature_changes_for_remotes",
-            None,
-            sql,
-            from_key,
-            to_key,
-            limit,
+            _get_all_user_signature_changes_for_remotes_txn,
         )
 
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index cfd24d2f06..a18317366c 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
@@ -28,12 +27,7 @@ from prometheus_client import Counter
 from twisted.internet import defer
 
 import synapse.metrics
-from synapse.api.constants import (
-    EventContentFields,
-    EventTypes,
-    Membership,
-    RelationTypes,
-)
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase  # noqa: F401
@@ -48,8 +42,8 @@ from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
-    from synapse.storage.data_stores.main import DataStore
     from synapse.server import HomeServer
+    from synapse.storage.data_stores.main import DataStore
 
 
 logger = logging.getLogger(__name__)
@@ -820,7 +814,6 @@ class PersistEventsStore:
             "event_reference_hashes",
             "event_search",
             "event_to_state_groups",
-            "local_invites",
             "state_events",
             "rejections",
             "redactions",
@@ -1197,65 +1190,27 @@ class PersistEventsStore:
                 (event.state_key,),
             )
 
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened. If the event is an
-            # outlier it is only current if its an "out of band membership",
-            # like a remote invite or a rejection of a remote invite.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_out_of_band_membership()
-            )
-            is_mine = self.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self.db.simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        },
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(
-                        sql,
-                        (
-                            event.internal_metadata.stream_ordering,
-                            event.event_id,
-                            event.room_id,
-                            event.state_key,
-                        ),
-                    )
-
-                # We also update the `local_current_membership` table with
-                # latest invite info. This will usually get updated by the
-                # `current_state_events` handling, unless its an outlier.
-                if event.internal_metadata.is_outlier():
-                    # This should only happen for out of band memberships, so
-                    # we add a paranoia check.
-                    assert event.internal_metadata.is_out_of_band_membership()
-
-                    self.db.simple_upsert_txn(
-                        txn,
-                        table="local_current_membership",
-                        keyvalues={
-                            "room_id": event.room_id,
-                            "user_id": event.state_key,
-                        },
-                        values={
-                            "event_id": event.event_id,
-                            "membership": event.membership,
-                        },
-                    )
+            # We update the local_current_membership table only if the event is
+            # "current", i.e., its something that has just happened.
+            #
+            # This will usually get updated by the `current_state_events` handling,
+            # unless it's an outlier, and an outlier is only "current" if it's an "out of
+            # band membership", like a remote invite or a rejection of a remote invite.
+            if (
+                self.is_mine_id(event.state_key)
+                and not backfilled
+                and event.internal_metadata.is_outlier()
+                and event.internal_metadata.is_out_of_band_membership()
+            ):
+                self.db.simple_upsert_txn(
+                    txn,
+                    table="local_current_membership",
+                    keyvalues={"room_id": event.room_id, "user_id": event.state_key},
+                    values={
+                        "event_id": event.event_id,
+                        "membership": event.membership,
+                    },
+                )
 
     def _handle_event_relations(self, txn, event):
         """Handles inserting relation data during peristence of events
@@ -1592,16 +1547,8 @@ class PersistEventsStore:
         create a leave event for it.
         """
 
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
         def f(txn, stream_ordering):
-            txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
-            # We also clear this entry from `local_current_membership`.
+            # Clear this entry from `local_current_membership`.
             # Ideally we'd point to a leave event, but we don't have one, so
             # nevermind.
             self.db.simple_delete_txn(
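
The rewritten guard in the membership hunk above collapses the old is_new_state/is_mine branching into a single condition: only locally-owned, non-backfilled, out-of-band outlier memberships reach the local_current_membership upsert, since everything else is handled via the current_state_events path. The same predicate, spelled out as a hypothetical standalone function:

def should_upsert_local_current_membership(
    is_mine: bool, backfilled: bool, is_outlier: bool, is_out_of_band: bool
) -> bool:
    """Illustrative mirror of the inline condition above."""
    # Backfilled or remote-user memberships are never "current" for this table.
    if backfilled or not is_mine:
        return False
    # Non-outliers are picked up by the current_state_events handling instead.
    return is_outlier and is_out_of_band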
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index a48c7a96ca..01cad7d4fa 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import BackfillStream
+from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
@@ -80,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore):
             # We are the process in charge of generating stream ids for events,
             # so instantiate ID generators based on the database
             self._stream_id_gen = StreamIdGenerator(
-                db_conn,
-                "events",
-                "stream_ordering",
-                extra_tables=[("local_invites", "stream_id")],
+                db_conn, "events", "stream_ordering",
             )
             self._backfill_id_gen = StreamIdGenerator(
                 db_conn,
@@ -113,9 +112,9 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_ongoing = 0
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "events":
+        if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(token)
-        elif stream_name == "backfill":
+        elif stream_name == BackfillStream.NAME:
             self._backfill_id_gen.advance(-token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index fb1361f1c1..4fb9f9850c 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Tuple
+
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_groups_changes_for_user", _get_groups_changes_for_user_txn
         )
 
-    def get_all_groups_changes(self, from_token, to_token, limit):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
-            from_token
-        )
+    async def get_all_groups_changes(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for groups replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
+        """
+
+        last_id = int(last_id)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+
         if not has_changed:
-            return defer.succeed([])
+            return [], current_id, False
 
         def _get_all_groups_changes_txn(txn):
             sql = """
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
-            return [
-                (stream_id, group_id, user_id, gtype, json.loads(content_json))
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [
+                (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
                 for stream_id, group_id, user_id, gtype, content_json in txn
             ]
 
-        return self.db.runInteraction(
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
+
+            return updates, upto_token, limited
+
+        return await self.db.runInteraction(
             "get_all_groups_changes", _get_all_groups_changes_txn
         )
 
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index a93e1ef198..6546569139 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "event_push_summary",
             "pusher_throttle",
             "group_summary_rooms",
-            "local_invites",
             "room_account_data",
             "room_tags",
             "local_current_membership",
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 547b9d69cb..5461016240 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Iterator
+from typing import Iterable, Iterator, List, Tuple
 
 from canonicaljson import encode_canonical_json, json
 
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
         rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
         return rows
 
-    def get_all_updated_pushers(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed(([], []))
-
-        def get_all_updated_pushers_txn(txn):
-            sql = (
-                "SELECT id, user_name, access_token, profile_tag, kind,"
-                " app_id, app_display_name, device_display_name, pushkey, ts,"
-                " lang, data"
-                " FROM pushers"
-                " WHERE ? < id AND id <= ?"
-                " ORDER BY id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            updated = txn.fetchall()
-
-            sql = (
-                "SELECT stream_id, user_id, app_id, pushkey"
-                " FROM deleted_pushers"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            deleted = txn.fetchall()
+    async def get_all_updated_pushers_rows(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for pushers replication stream.
 
-            return updated, deleted
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
 
-        return self.db.runInteraction(
-            "get_all_updated_pushers", get_all_updated_pushers_txn
-        )
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
 
-    def get_all_updated_pushers_rows(self, last_id, current_id, limit):
-        """Get all the pushers that have changed between the given tokens.
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
 
-        Returns:
-            Deferred(list(tuple)): each tuple consists of:
-                stream_id (str)
-                user_id (str)
-                app_id (str)
-                pushkey (str)
-                was_deleted (bool): whether the pusher was added/updated (False)
-                    or deleted (True)
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return [], current_id, False
 
         def get_all_updated_pushers_rows_txn(txn):
-            sql = (
-                "SELECT id, user_name, app_id, pushkey"
-                " FROM pushers"
-                " WHERE ? < id AND id <= ?"
-                " ORDER BY id ASC LIMIT ?"
-            )
+            sql = """
+                SELECT id, user_name, app_id, pushkey
+                FROM pushers
+                WHERE ? < id AND id <= ?
+                ORDER BY id ASC LIMIT ?
+            """
             txn.execute(sql, (last_id, current_id, limit))
-            results = [list(row) + [False] for row in txn]
-
-            sql = (
-                "SELECT stream_id, user_id, app_id, pushkey"
-                " FROM deleted_pushers"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
+            updates = [
+                (stream_id, (user_name, app_id, pushkey, False))
+                for stream_id, user_name, app_id, pushkey in txn
+            ]
+
+            sql = """
+                SELECT stream_id, user_id, app_id, pushkey
+                FROM deleted_pushers
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
             txn.execute(sql, (last_id, current_id, limit))
+            updates.extend(
+                (stream_id, (user_name, app_id, pushkey, True))
+                for stream_id, user_name, app_id, pushkey in txn
+            )
+
+            updates.sort()  # Sort so that they're ordered by stream id
 
-            results.extend(list(row) + [True] for row in txn)
-            results.sort()  # Sort so that they're ordered by stream id
+            limited = False
+            upper_bound = current_id
+            if len(updates) >= limit:
+                limited = True
+                upper_bound = updates[-1][0]
 
-            return results
+            return updates, upper_bound, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
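Updated and deleted pushers are now merged into a single sorted stream, with a trailing boolean marking deletions. A consumer could split the merged rows back apart like so (an illustrative sketch, not part of this patch):

from typing import List, Tuple

def split_pusher_updates(updates: List[Tuple[int, tuple]]):
    """Separate merged pusher rows into changed and deleted pushers."""
    changed, deleted = [], []
    for stream_id, (user_name, app_id, pushkey, was_deleted) in updates:
        target = deleted if was_deleted else changed
        target.append((stream_id, user_name, app_id, pushkey))
    return changed, deleted
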
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 13e366536a..c473cf158f 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
 
         return total_media_quarantined
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+    async def get_all_new_public_rooms(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for public rooms replication stream.
+
+        Args:
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
+        Returns:
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
+        """
+        if last_id == current_id:
+            return [], current_id, False
+
         def get_all_new_public_rooms(txn):
             sql = """
                 SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
 
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            updates = [(row[0], row[1:]) for row in txn]
+            limited = False
+            upto_token = current_id
+            if len(updates) >= limit:
+                upto_token = updates[-1][0]
+                limited = True
 
-        if prev_id == current_id:
-            return defer.succeed([])
+            return updates, upto_token, limited
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index f8c776be3f..290317fd94 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         return deferred
 
-    @defer.inlineCallbacks
-    def get_all_updated_tags(self, last_id, current_id, limit):
-        """Get all the client tags that have changed on the server
+    async def get_all_updated_tags(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        """Get updates for tags replication stream.
+
         Args:
-            last_id(int): The position to fetch from.
-            current_id(int): The position to fetch up to.
+            instance_name: The writer we want to fetch updates from. Unused
+                here since there is only ever one writer.
+            last_id: The token to fetch updates from. Exclusive.
+            current_id: The token to fetch updates up to. Inclusive.
+            limit: The requested limit for the number of rows to return. The
+                function may return more or fewer rows.
+
         Returns:
-            A deferred list of tuples of stream_id int, user_id string,
-            room_id string, tag string and content string.
+            A tuple consisting of: the updates, a token to use to fetch
+            subsequent updates, and whether we returned fewer rows than exist
+            between the requested tokens due to the limit.
+
+            The token returned can be used in a subsequent call to this
+            function to get further updates.
+
+            The updates are a list of 2-tuples of stream ID and the row data.
         """
+
         if last_id == current_id:
-            return []
+            return [], current_id, False
 
         def get_all_updated_tags_txn(txn):
             sql = (
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        tag_ids = yield self.db.runInteraction(
+        tag_ids = await self.db.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
-                results.append((stream_id, user_id, room_id, tag_json))
+                results.append((stream_id, (user_id, room_id, tag_json)))
 
             return results
 
         batch_size = 50
         results = []
         for i in range(0, len(tag_ids), batch_size):
-            tags = yield self.db.runInteraction(
+            tags = await self.db.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
                 tag_ids[i : i + batch_size],
             )
             results.extend(tags)
 
-        return results
+        limited = False
+        upto_token = current_id
+        if len(results) >= limit:
+            upto_token = results[-1][0]
+            limited = True
+
+        return results, upto_token, limited
 
     @defer.inlineCallbacks
     def get_updated_tags(self, user_id, stream_id):
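
Tag content above is deliberately fetched in batches of 50 so that no single transaction grows too large. The chunking is the standard slice idiom; as a hypothetical standalone generator:

def chunks(seq, size: int = 50):
    """Yield successive fixed-size slices of a sequence."""
    for i in range(0, len(seq), size):
        yield seq[i : i + size]

# e.g. the explicit index loop above could be written as:
#     for batch in chunks(tag_ids):
#         results.extend(await self.db.runInteraction(
#             "get_all_updated_tag_content", get_tag_content, batch))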
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
index ec2f38c373..4c044b1a15 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union
 
 import attr
 
-import synapse.util.stringutils as stringutils
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.types import JsonDict
+from synapse.util import stringutils
 
 
 @attr.s
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 6c7d08a6f2..a31588080d 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -92,7 +92,7 @@ class PostgresEngine(BaseDatabaseEngine):
             errors.append("    - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
 
         if ctype != "C":
-            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+            errors.append("    - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
 
         if errors:
             raise IncorrectDatabaseSetup(
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index daff81c5ee..2d2b560e74 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,12 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Iterable, Iterator, List, Tuple
 
 from typing_extensions import Protocol
 
-
 """
 Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
 """
diff --git a/synapse/types.py b/synapse/types.py
index acf60baddc..238b938064 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError
 if sys.version_info[:3] >= (3, 6, 0):
     from typing import Collection
 else:
-    from typing import Sized, Iterable, Container
+    from typing import Container, Iterable, Sized
 
     T_co = TypeVar("T_co", covariant=True)
 
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..43c2e0ac23 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -48,3 +48,26 @@ def check_3pid_allowed(hs, medium, address):
         return True
 
     return False
+
+
+def canonicalise_email(address: str) -> str:
+    """'Canonicalise' email address
+    Case folding of local part of email address and lowercase domain part
+    See MSC2265, https://github.com/matrix-org/matrix-doc/pull/2265
+
+    Args:
+        address: email address to be canonicalised
+    Returns:
+        The canonical form of the email address
+    Raises:
+        ValueError if the address could not be parsed.
+    """
+
+    address = address.strip()
+
+    parts = address.split("@")
+    if len(parts) != 2:
+        logger.debug("Couldn't parse email address %s", address)
+        raise ValueError("Unable to parse email address")
+
+    return parts[0].casefold() + "@" + parts[1].lower()