summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r-- synapse/api/constants.py | 1
-rw-r--r-- synapse/api/errors.py | 21
-rw-r--r-- synapse/app/generic_worker.py | 70
-rw-r--r-- synapse/config/_base.py | 2
-rw-r--r-- synapse/config/captcha.py | 5
-rw-r--r-- synapse/config/database.py | 160
-rw-r--r-- synapse/config/metrics.py | 2
-rw-r--r-- synapse/config/password.py | 39
-rw-r--r-- synapse/config/registration.py | 27
-rw-r--r-- synapse/crypto/keyring.py | 4
-rw-r--r-- synapse/federation/federation_base.py | 28
-rw-r--r-- synapse/federation/federation_client.py | 19
-rw-r--r-- synapse/federation/federation_server.py | 8
-rw-r--r-- synapse/federation/send_queue.py | 2
-rw-r--r-- synapse/federation/sender/__init__.py | 9
-rw-r--r-- synapse/handlers/auth.py | 151
-rw-r--r-- synapse/handlers/cas_handler.py | 204
-rw-r--r-- synapse/handlers/device.py | 14
-rw-r--r-- synapse/handlers/directory.py | 6
-rw-r--r-- synapse/handlers/federation.py | 25
-rw-r--r-- synapse/handlers/message.py | 57
-rw-r--r-- synapse/handlers/password_policy.py | 93
-rw-r--r-- synapse/handlers/presence.py | 4
-rw-r--r-- synapse/handlers/profile.py | 16
-rw-r--r-- synapse/handlers/room_member.py | 3
-rw-r--r-- synapse/handlers/saml_handler.py | 55
-rw-r--r-- synapse/handlers/set_password.py | 2
-rw-r--r-- synapse/handlers/sync.py | 11
-rw-r--r-- synapse/handlers/typing.py | 11
-rw-r--r-- synapse/http/request_metrics.py | 6
-rw-r--r-- synapse/http/site.py | 13
-rw-r--r-- synapse/logging/_structured.py | 4
-rw-r--r-- synapse/logging/context.py | 239
-rw-r--r-- synapse/logging/scopecontextmanager.py | 13
-rw-r--r-- synapse/replication/http/__init__.py | 2
-rw-r--r-- synapse/replication/http/streams.py | 78
-rw-r--r-- synapse/replication/slave/storage/_base.py | 14
-rw-r--r-- synapse/replication/slave/storage/devices.py | 36
-rw-r--r-- synapse/replication/slave/storage/pushers.py | 3
-rw-r--r-- synapse/replication/tcp/client.py | 9
-rw-r--r-- synapse/replication/tcp/commands.py | 70
-rw-r--r-- synapse/replication/tcp/protocol.py | 215
-rw-r--r-- synapse/replication/tcp/resource.py | 61
-rw-r--r-- synapse/replication/tcp/streams/__init__.py | 76
-rw-r--r-- synapse/replication/tcp/streams/_base.py | 368
-rw-r--r-- synapse/replication/tcp/streams/events.py | 5
-rw-r--r-- synapse/replication/tcp/streams/federation.py | 33
-rw-r--r-- synapse/res/templates/sso_auth_confirm.html | 14
-rw-r--r-- synapse/rest/__init__.py | 2
-rw-r--r-- synapse/rest/admin/__init__.py | 7
-rw-r--r-- synapse/rest/admin/rooms.py | 79
-rw-r--r-- synapse/rest/client/v1/login.py | 174
-rw-r--r-- synapse/rest/client/v2_alpha/account.py | 40
-rw-r--r-- synapse/rest/client/v2_alpha/auth.py | 95
-rw-r--r-- synapse/rest/client/v2_alpha/devices.py | 12
-rw-r--r-- synapse/rest/client/v2_alpha/keys.py | 6
-rw-r--r-- synapse/rest/client/v2_alpha/password_policy.py | 58
-rw-r--r-- synapse/rest/client/v2_alpha/register.py | 8
-rw-r--r-- synapse/rest/client/v2_alpha/room_keys.py | 2
-rw-r--r-- synapse/rest/media/v1/download_resource.py | 3
-rw-r--r-- synapse/rest/media/v1/media_repository.py | 110
-rw-r--r-- synapse/rest/media/v1/preview_url_resource.py | 37
-rw-r--r-- synapse/rest/media/v1/thumbnail_resource.py | 54
-rw-r--r-- synapse/server.py | 48
-rw-r--r-- synapse/server.pyi | 2
-rw-r--r-- synapse/storage/data_stores/main/__init__.py | 5
-rw-r--r-- synapse/storage/data_stores/main/cache.py | 44
-rw-r--r-- synapse/storage/data_stores/main/deviceinbox.py | 88
-rw-r--r-- synapse/storage/data_stores/main/devices.py | 221
-rw-r--r-- synapse/storage/data_stores/main/directory.py | 26
-rw-r--r-- synapse/storage/data_stores/main/e2e_room_keys.py | 3
-rw-r--r-- synapse/storage/data_stores/main/end_to_end_keys.py | 14
-rw-r--r-- synapse/storage/data_stores/main/events.py | 114
-rw-r--r-- synapse/storage/data_stores/main/events_worker.py | 118
-rw-r--r-- synapse/storage/data_stores/main/media_repository.py | 4
-rw-r--r-- synapse/storage/data_stores/main/presence.py | 23
-rw-r--r-- synapse/storage/data_stores/main/push_rule.py | 1
-rw-r--r-- synapse/storage/data_stores/main/room.py | 40
-rw-r--r-- synapse/storage/database.py | 11
-rw-r--r-- synapse/util/metrics.py | 4
-rw-r--r-- synapse/util/patch_inline_callbacks.py | 36
-rw-r--r-- synapse/util/stringutils.py | 21
82 files changed, 2399 insertions, 1419 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index cc8577552b..fda2c2e5bb 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -61,6 +61,7 @@ class LoginType(object):
     MSISDN = "m.login.msisdn"
     RECAPTCHA = "m.login.recaptcha"
     TERMS = "m.login.terms"
+    SSO = "org.matrix.login.sso"
     DUMMY = "m.login.dummy"
 
     # Only for C/S API v1
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 616942b057..11da016ac5 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -64,6 +64,13 @@ class Codes(object):
     INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
     WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
     EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+    PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+    PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+    PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+    PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+    PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+    PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+    WEAK_PASSWORD = "M_WEAK_PASSWORD"
     INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
     USER_DEACTIVATED = "M_USER_DEACTIVATED"
     BAD_ALIAS = "M_BAD_ALIAS"
@@ -439,6 +446,20 @@ class IncompatibleRoomVersionError(SynapseError):
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
+class PasswordRefusedError(SynapseError):
+    """A password has been refused, either during password reset/change or registration.
+    """
+
+    def __init__(
+        self,
+        msg="This password doesn't comply with the server's policy",
+        errcode=Codes.WEAK_PASSWORD,
+    ):
+        super(PasswordRefusedError, self).__init__(
+            code=400, msg=msg, errcode=errcode,
+        )
+
+
 class RequestSendFailed(RuntimeError):
     """Sending a HTTP request over federation failed due to not being able to
     talk to the remote server for some reason.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5363642d64..1ee266f7c5 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,12 +65,24 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams._base import (
+from synapse.replication.tcp.commands import ClearUserSyncsCommand
+from synapse.replication.tcp.streams import (
+    AccountDataStream,
     DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PushersStream,
+    PushRulesStream,
     ReceiptsStream,
+    TagAccountDataStream,
     ToDeviceStream,
+    TypingStream,
+)
+from synapse.replication.tcp.streams.events import (
+    EventsStream,
+    EventsStreamEventRow,
+    EventsStreamRow,
 )
-from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.client.v1 import events
 from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -113,7 +125,6 @@ from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse.app.generic_worker")
@@ -222,6 +233,7 @@ class GenericWorkerPresence(object):
         self.user_to_num_current_syncs = {}
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self.instance_id = hs.get_instance_id()
 
         active_presence = self.store.take_presence_startup_info()
         self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -234,13 +246,24 @@ class GenericWorkerPresence(object):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        self.process_id = random_string(16)
-        logger.info("Presence process_id is %r", self.process_id)
+        hs.get_reactor().addSystemEventTrigger(
+            "before",
+            "shutdown",
+            run_as_background_process,
+            "generic_presence.on_shutdown",
+            self._on_shutdown,
+        )
+
+    def _on_shutdown(self):
+        if self.hs.config.use_presence:
+            self.hs.get_tcp_replication().send_command(
+                ClearUserSyncsCommand(self.instance_id)
+            )
 
     def send_user_sync(self, user_id, is_syncing, last_sync_ms):
         if self.hs.config.use_presence:
             self.hs.get_tcp_replication().send_user_sync(
-                user_id, is_syncing, last_sync_ms
+                self.instance_id, user_id, is_syncing, last_sync_ms
             )
 
     def mark_as_coming_online(self, user_id):
@@ -390,6 +413,9 @@ class GenericWorkerTyping(object):
             self._room_serials[row.room_id] = token
             self._room_typing[row.room_id] = row.user_ids
 
+    def get_current_token(self) -> int:
+        return self._latest_room_serial
+
 
 class GenericWorkerSlavedStore(
     # FIXME(#3714): We need to add UserDirectoryStore as we write directly
@@ -626,7 +652,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
             if self.send_handler:
                 self.send_handler.process_replication_rows(stream_name, token, rows)
 
-            if stream_name == "events":
+            if stream_name == EventsStream.NAME:
                 # We shouldn't get multiple rows per token for events stream, so
                 # we don't need to optimise this for multiple rows.
                 for row in rows:
@@ -649,43 +675,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
                     )
 
                 await self.pusher_pool.on_new_notifications(token, token)
-            elif stream_name == "push_rules":
+            elif stream_name == PushRulesStream.NAME:
                 self.notifier.on_new_event(
                     "push_rules_key", token, users=[row.user_id for row in rows]
                 )
-            elif stream_name in ("account_data", "tag_account_data"):
+            elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
                 self.notifier.on_new_event(
                     "account_data_key", token, users=[row.user_id for row in rows]
                 )
-            elif stream_name == "receipts":
+            elif stream_name == ReceiptsStream.NAME:
                 self.notifier.on_new_event(
                     "receipt_key", token, rooms=[row.room_id for row in rows]
                 )
                 await self.pusher_pool.on_new_receipts(
                     token, token, {row.room_id for row in rows}
                 )
-            elif stream_name == "typing":
+            elif stream_name == TypingStream.NAME:
                 self.typing_handler.process_replication_rows(token, rows)
                 self.notifier.on_new_event(
                     "typing_key", token, rooms=[row.room_id for row in rows]
                 )
-            elif stream_name == "to_device":
+            elif stream_name == ToDeviceStream.NAME:
                 entities = [row.entity for row in rows if row.entity.startswith("@")]
                 if entities:
                     self.notifier.on_new_event("to_device_key", token, users=entities)
-            elif stream_name == "device_lists":
+            elif stream_name == DeviceListsStream.NAME:
                 all_room_ids = set()
                 for row in rows:
-                    room_ids = await self.store.get_rooms_for_user(row.user_id)
-                    all_room_ids.update(room_ids)
+                    if row.entity.startswith("@"):
+                        room_ids = await self.store.get_rooms_for_user(row.entity)
+                        all_room_ids.update(room_ids)
                 self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
-            elif stream_name == "presence":
+            elif stream_name == PresenceStream.NAME:
                 await self.presence_handler.process_replication_rows(token, rows)
-            elif stream_name == "receipts":
+            elif stream_name == GroupServerStream.NAME:
                 self.notifier.on_new_event(
                     "groups_key", token, users=[row.user_id for row in rows]
                 )
-            elif stream_name == "pushers":
+            elif stream_name == PushersStream.NAME:
                 for row in rows:
                     if row.deleted:
                         self.stop_pusher(row.user_id, row.app_id, row.pushkey)
@@ -774,7 +801,10 @@ class FederationSenderHandler(object):
 
         # ... as well as device updates and messages
         elif stream_name == DeviceListsStream.NAME:
-            hosts = {row.destination for row in rows}
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
 
@@ -789,7 +819,7 @@ class FederationSenderHandler(object):
     async def _on_new_receipts(self, rows):
         """
         Args:
-            rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
+            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
                 new receipts to be processed
         """
         for receipt in rows:
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index ba846042c4..efe2af5504 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -294,7 +294,6 @@ class RootConfig(object):
         report_stats=None,
         open_private_ports=False,
         listeners=None,
-        database_conf=None,
         tls_certificate_path=None,
         tls_private_key_path=None,
         acme_domain=None,
@@ -367,7 +366,6 @@ class RootConfig(object):
                 report_stats=report_stats,
                 open_private_ports=open_private_ports,
                 listeners=listeners,
-                database_conf=database_conf,
                 tls_certificate_path=tls_certificate_path,
                 tls_private_key_path=tls_private_key_path,
                 acme_domain=acme_domain,
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index f0171bb5b2..56c87fa296 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -24,7 +24,6 @@ class CaptchaConfig(Config):
         self.enable_registration_captcha = config.get(
             "enable_registration_captcha", False
         )
-        self.captcha_bypass_secret = config.get("captcha_bypass_secret")
         self.recaptcha_siteverify_api = config.get(
             "recaptcha_siteverify_api",
             "https://www.recaptcha.net/recaptcha/api/siteverify",
@@ -49,10 +48,6 @@ class CaptchaConfig(Config):
         #
         #enable_registration_captcha: false
 
-        # A secret key used to bypass the captcha test entirely.
-        #
-        #captcha_bypass_secret: "YOUR_SECRET_HERE"
-
         # The API endpoint to use for verifying m.login.recaptcha responses.
         #
         #recaptcha_siteverify_api: "https://www.recaptcha.net/recaptcha/api/siteverify"
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 219b32f670..c27fef157b 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,14 +15,65 @@
 # limitations under the License.
 import logging
 import os
-from textwrap import indent
-
-import yaml
 
 from synapse.config._base import Config, ConfigError
 
 logger = logging.getLogger(__name__)
 
+NON_SQLITE_DATABASE_PATH_WARNING = """\
+Ignoring 'database_path' setting: not using a sqlite3 database.
+--------------------------------------------------------------------------------
+"""
+
+DEFAULT_CONFIG = """\
+## Database ##
+
+# The 'database' setting defines the database that synapse uses to store all of
+# its data.
+#
+# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
+# 'psycopg2' (for PostgreSQL).
+#
+# 'args' gives options which are passed through to the database engine,
+# except for options starting 'cp_', which are used to configure the Twisted
+# connection pool. For a reference to valid arguments, see:
+#   * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+#   * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
+#   * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
+#
+#
+# Example SQLite configuration:
+#
+#database:
+#  name: sqlite3
+#  args:
+#    database: /path/to/homeserver.db
+#
+#
+# Example Postgres configuration:
+#
+#database:
+#  name: psycopg2
+#  args:
+#    user: synapse
+#    password: secretpassword
+#    database: synapse
+#    host: localhost
+#    cp_min: 5
+#    cp_max: 10
+#
+# For more information on using Synapse with Postgres, see `docs/postgres.md`.
+#
+database:
+  name: sqlite3
+  args:
+    database: %(database_path)s
+
+# Number of events to cache in memory.
+#
+#event_cache_size: 10K
+"""
+
 
 class DatabaseConnectionConfig:
     """Contains the connection config for a particular database.
@@ -36,10 +88,12 @@ class DatabaseConnectionConfig:
     """
 
     def __init__(self, name: str, db_config: dict):
-        if db_config["name"] not in ("sqlite3", "psycopg2"):
-            raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+        db_engine = db_config.get("name", "sqlite3")
+
+        if db_engine not in ("sqlite3", "psycopg2"):
+            raise ConfigError("Unsupported database type %r" % (db_engine,))
 
-        if db_config["name"] == "sqlite3":
+        if db_engine == "sqlite3":
             db_config.setdefault("args", {}).update(
                 {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
             )
@@ -56,6 +110,11 @@ class DatabaseConnectionConfig:
 class DatabaseConfig(Config):
     section = "database"
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.databases = []
+
     def read_config(self, config, **kwargs):
         self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
 
@@ -76,12 +135,13 @@ class DatabaseConfig(Config):
 
         multi_database_config = config.get("databases")
         database_config = config.get("database")
+        database_path = config.get("database_path")
 
         if multi_database_config and database_config:
             raise ConfigError("Can't specify both 'database' and 'datbases' in config")
 
         if multi_database_config:
-            if config.get("database_path"):
+            if database_path:
                 raise ConfigError("Can't specify 'database_path' with 'databases'")
 
             self.databases = [
@@ -89,65 +149,55 @@ class DatabaseConfig(Config):
                 for name, db_conf in multi_database_config.items()
             ]
 
-        else:
-            if database_config is None:
-                database_config = {"name": "sqlite3", "args": {}}
-
+        if database_config:
             self.databases = [DatabaseConnectionConfig("master", database_config)]
 
-            self.set_databasepath(config.get("database_path"))
-
-    def generate_config_section(self, data_dir_path, database_conf, **kwargs):
-        if not database_conf:
-            database_path = os.path.join(data_dir_path, "homeserver.db")
-            database_conf = (
-                """# The database engine name
-          name: "sqlite3"
-          # Arguments to pass to the engine
-          args:
-            # Path to the database
-            database: "%(database_path)s"
-            """
-                % locals()
-            )
-        else:
-            database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
+        if database_path:
+            if self.databases and self.databases[0].name != "sqlite3":
+                logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
+                return
 
-        return (
-            """\
-        ## Database ##
+            database_config = {"name": "sqlite3", "args": {}}
+            self.databases = [DatabaseConnectionConfig("master", database_config)]
+            self.set_databasepath(database_path)
 
-        database:
-          %(database_conf)s
-        # Number of events to cache in memory.
-        #
-        #event_cache_size: 10K
-        """
-            % locals()
-        )
+    def generate_config_section(self, data_dir_path, **kwargs):
+        return DEFAULT_CONFIG % {
+            "database_path": os.path.join(data_dir_path, "homeserver.db")
+        }
 
     def read_arguments(self, args):
-        self.set_databasepath(args.database_path)
+        """
+        Cases for the cli input:
+          - If no databases are configured and no database_path is set, raise.
+          - No databases and only database_path available ==> sqlite3 db.
+          - If there are multiple databases and a database_path raise an error.
+          - If the database set in the config file is sqlite then
+            overwrite with the command line argument.
+        """
 
-    def set_databasepath(self, database_path):
-        if database_path is None:
+        if args.database_path is None:
+            if not self.databases:
+                raise ConfigError("No database config provided")
             return
 
-        if database_path != ":memory:":
-            database_path = self.abspath(database_path)
+        if len(self.databases) == 0:
+            database_config = {"name": "sqlite3", "args": {}}
+            self.databases = [DatabaseConnectionConfig("master", database_config)]
+            self.set_databasepath(args.database_path)
+            return
+
+        if self.get_single_database().name == "sqlite3":
+            self.set_databasepath(args.database_path)
+        else:
+            logger.warning(NON_SQLITE_DATABASE_PATH_WARNING)
 
-        # We only support setting a database path if we have a single sqlite3
-        # database.
-        if len(self.databases) != 1:
-            raise ConfigError("Cannot specify 'database_path' with multiple databases")
+    def set_databasepath(self, database_path):
 
-        database = self.get_single_database()
-        if database.config["name"] != "sqlite3":
-            # We don't raise here as we haven't done so before for this case.
-            logger.warn("Ignoring 'database_path' for non-sqlite3 database")
-            return
+        if database_path != ":memory:":
+            database_path = self.abspath(database_path)
 
-        database.config["args"]["database"] = database_path
+        self.databases[0].config["args"]["database"] = database_path
 
     @staticmethod
     def add_arguments(parser):
@@ -162,7 +212,7 @@ class DatabaseConfig(Config):
     def get_single_database(self) -> DatabaseConnectionConfig:
         """Returns the database if there is only one, useful for e.g. tests
         """
-        if len(self.databases) != 1:
+        if not self.databases:
             raise Exception("More than one database exists")
 
         return self.databases[0]
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 22538153e1..6f517a71d0 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -86,7 +86,7 @@ class MetricsConfig(Config):
         # enabled by default, either for performance reasons or limited use.
         #
         metrics_flags:
-            # Publish synapse_federation_known_servers, a g auge of the number of
+            # Publish synapse_federation_known_servers, a gauge of the number of
             # servers this homeserver knows about, including itself. May cause
             # performance problems on large homeservers.
             #
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 2a634ac751..9c0ea8c30a 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -31,6 +31,10 @@ class PasswordConfig(Config):
         self.password_localdb_enabled = password_config.get("localdb_enabled", True)
         self.password_pepper = password_config.get("pepper", "")
 
+        # Password policy
+        self.password_policy = password_config.get("policy") or {}
+        self.password_policy_enabled = self.password_policy.get("enabled", False)
+
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
         password_config:
@@ -48,4 +52,39 @@ class PasswordConfig(Config):
            # DO NOT CHANGE THIS AFTER INITIAL SETUP!
            #
            #pepper: "EVEN_MORE_SECRET"
+
+           # Define and enforce a password policy. Each parameter is optional.
+           # This is an implementation of MSC2000.
+           #
+           policy:
+              # Whether to enforce the password policy.
+              # Defaults to 'false'.
+              #
+              #enabled: true
+
+              # Minimum accepted length for a password.
+              # Defaults to 0.
+              #
+              #minimum_length: 15
+
+              # Whether a password must contain at least one digit.
+              # Defaults to 'false'.
+              #
+              #require_digit: true
+
+              # Whether a password must contain at least one symbol.
+              # A symbol is any character that's not a number or a letter.
+              # Defaults to 'false'.
+              #
+              #require_symbol: true
+
+              # Whether a password must contain at least one lowercase letter.
+              # Defaults to 'false'.
+              #
+              #require_lowercase: true
+
+              # Whether a password must contain at least one uppercase letter.
+              # Defaults to 'false'.
+              #
+              #require_uppercase: true
         """
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 9bb3beedbc..e7ea3a01cb 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -129,6 +129,10 @@ class RegistrationConfig(Config):
                 raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
         self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
 
+        self.enable_set_displayname = config.get("enable_set_displayname", True)
+        self.enable_set_avatar_url = config.get("enable_set_avatar_url", True)
+        self.enable_3pid_changes = config.get("enable_3pid_changes", True)
+
         self.disable_msisdn_registration = config.get(
             "disable_msisdn_registration", False
         )
@@ -330,6 +334,29 @@ class RegistrationConfig(Config):
             #email: https://example.com     # Delegate email sending to example.com
             #msisdn: http://localhost:8090  # Delegate SMS sending to this local process
 
+        # Whether users are allowed to change their displayname after it has
+        # been initially set. Useful when provisioning users based on the
+        # contents of a third-party directory.
+        #
+        # Does not apply to server administrators. Defaults to 'true'
+        #
+        #enable_set_displayname: false
+
+        # Whether users are allowed to change their avatar after it has been
+        # initially set. Useful when provisioning users based on the contents
+        # of a third-party directory.
+        #
+        # Does not apply to server administrators. Defaults to 'true'
+        #
+        #enable_set_avatar_url: false
+
+        # Whether users can change the 3PIDs associated with their accounts
+        # (email address and msisdn).
+        #
+        # Defaults to 'true'
+        #
+        #enable_3pid_changes: false
+
         # Users who register on this homeserver will automatically be joined
         # to these rooms
         #
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 983f0ead8c..a9f4025bfe 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -43,8 +43,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
     preserve_fn,
     run_in_background,
@@ -236,7 +236,7 @@ class Keyring(object):
         """
 
         try:
-            ctx = LoggingContext.current_context()
+            ctx = current_context()
 
             # map from server name to a set of outstanding request ids
             server_to_request_ids = {}
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5c991e5412..4b115aac04 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -25,19 +25,15 @@ from twisted.python.failure import Failure
 
 from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import (
-    KNOWN_ROOM_VERSIONS,
-    EventFormatVersions,
-    RoomVersion,
-)
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.crypto.keyring import Keyring
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
-    LoggingContext,
     PreserveLoggingContext,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.types import JsonDict, get_domain_from_id
@@ -55,13 +51,15 @@ class FederationBase(object):
         self.store = hs.get_datastore()
         self._clock = hs.get_clock()
 
-    def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
+    def _check_sigs_and_hash(
+        self, room_version: RoomVersion, pdu: EventBase
+    ) -> Deferred:
         return make_deferred_yieldable(
             self._check_sigs_and_hashes(room_version, [pdu])[0]
         )
 
     def _check_sigs_and_hashes(
-        self, room_version: str, pdus: List[EventBase]
+        self, room_version: RoomVersion, pdus: List[EventBase]
     ) -> List[Deferred]:
         """Checks that each of the received events is correctly signed by the
         sending server.
@@ -80,7 +78,7 @@ class FederationBase(object):
         """
         deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
 
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
         def callback(_, pdu: EventBase):
             with PreserveLoggingContext(ctx):
@@ -146,7 +144,7 @@ class PduToCheckSig(
 
 
 def _check_sigs_on_pdus(
-    keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+    keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
 ) -> List[Deferred]:
     """Check that the given events are correctly signed
 
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
         for p in pdus
     ]
 
-    v = KNOWN_ROOM_VERSIONS.get(room_version)
-    if not v:
-        raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
     # First we check that the sender event is signed by the sender's domain
     # (except if its a 3pid invite, in which case it may be sent by any server)
     pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
             (
                 p.sender_domain,
                 p.redacted_pdu_json,
-                p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                 p.pdu.event_id,
             )
             for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
     # event id's domain (normally only the case for joins/leaves), and add additional
     # checks. Only do this if the room version has a concept of event ID domain
     # (ie, the room version uses old-style non-hash event IDs).
-    if v.event_format == EventFormatVersions.V1:
+    if room_version.event_format == EventFormatVersions.V1:
         pdus_to_check_event_id = [
             p
             for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
                 (
                     get_domain_from_id(p.pdu.event_id),
                     p.redacted_pdu_json,
-                    p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+                    p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
                     p.pdu.event_id,
                 )
                 for p in pdus_to_check_event_id
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8c6b839478..a0071fec94 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
         # FIXME: We should handle signature failures more gracefully.
         pdus[:] = await make_deferred_yieldable(
             defer.gatherResults(
-                self._check_sigs_and_hashes(room_version.identifier, pdus),
-                consumeErrors=True,
+                self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
 
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
                     pdu = pdu_list[0]
 
                     # Check signatures are correct.
-                    signed_pdu = await self._check_sigs_and_hash(
-                        room_version.identifier, pdu
-                    )
+                    signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
 
                     break
 
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
         self,
         origin: str,
         pdus: List[EventBase],
-        room_version: str,
+        room_version: RoomVersion,
         outlier: bool = False,
         include_none: bool = False,
     ) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
                         self.get_pdu(
                             destinations=[pdu.origin],
                             event_id=pdu.event_id,
-                            room_version=room_version,  # type: ignore
+                            room_version=room_version,
                             outlier=outlier,
                             timeout=10000,
                         )
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
         ]
 
         signed_auth = await self._check_sigs_and_hash_and_fetch(
-            destination, auth_chain, outlier=True, room_version=room_version.identifier
+            destination, auth_chain, outlier=True, room_version=room_version
         )
 
         signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
                 destination,
                 list(pdus.values()),
                 outlier=True,
-                room_version=room_version.identifier,
+                room_version=room_version,
             )
 
             valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
         pdu = event_from_pdu_json(pdu_dict, room_version)
 
         # Check signatures are correct.
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         # FIXME: We should handle signature failures more gracefully.
 
@@ -948,7 +945,7 @@ class FederationClient(FederationBase):
             ]
 
             signed_events = await self._check_sigs_and_hash_and_fetch(
-                destination, events, outlier=False, room_version=room_version.identifier
+                destination, events, outlier=False, room_version=room_version
             )
         except HttpResponseException as e:
             if not e.code == 400:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 275b9c99d7..89d521bc31 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
         pdu = event_from_pdu_json(content, room_version)
         origin_host, _ = parse_server_name(origin)
         await self.check_server_matches_acl(origin_host, pdu.room_id)
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
         ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
         time_now = self._clock.time_msec()
         return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
 
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
 
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         res_pdus = await self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
 
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
 
-        pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+        pdu = await self._check_sigs_and_hash(room_version, pdu)
 
         await self.handler.on_send_leave_request(origin, pdu)
         return {}
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
                 logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
 
         # We've already checked that we know the room version by this point
-        room_version = await self.store.get_room_version_id(pdu.room_id)
+        room_version = await self.store.get_room_version(pdu.room_id)
 
         # Check signature.
         try:
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 876fb0e245..e1700ca8aa 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -477,7 +477,7 @@ def process_rows_for_federation(transaction_queue, rows):
 
     Args:
         transaction_queue (FederationSender)
-        rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+        rows (list(synapse.replication.tcp.streams.federation.FederationStream.FederationStreamRow))
     """
 
     # The federation stream contains a bunch of different types of
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 233cb33daf..a477578e44 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -499,4 +499,13 @@ class FederationSender(object):
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
     def get_current_token(self) -> int:
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
         return 0
+
+    async def get_replication_rows(
+        self, from_token, to_token, limit, federation_ack=None
+    ):
+        # Dummy implementation for case where federation sender isn't offloaded
+        # to a worker.
+        return []
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7860f9625e..7c09d15a72 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -53,6 +53,31 @@ from ._base import BaseHandler
 logger = logging.getLogger(__name__)
 
 
+SUCCESS_TEMPLATE = """
+<html>
+<head>
+<title>Success!</title>
+<meta name='viewport' content='width=device-width, initial-scale=1,
+    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
+<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
+<script>
+if (window.onAuthDone) {
+    window.onAuthDone();
+} else if (window.opener && window.opener.postMessage) {
+     window.opener.postMessage("authDone", "*");
+}
+</script>
+</head>
+<body>
+    <div>
+        <p>Thank you</p>
+        <p>You may now close this window and return to the application</p>
+    </div>
+</body>
+</html>
+"""
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -91,6 +116,7 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
+        self._saml2_enabled = hs.config.saml2_enabled
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
@@ -106,6 +132,13 @@ class AuthHandler(BaseHandler):
                     if t not in login_types:
                         login_types.append(t)
         self._supported_login_types = login_types
+        # Login types and UI Auth types have a heavy overlap, but are not
+        # necessarily identical. Login types have SSO (and other login types)
+        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServlet.on_GET.
+        ui_auth_types = login_types.copy()
+        if self._saml2_enabled:
+            ui_auth_types.append(LoginType.SSO)
+        self._supported_ui_auth_types = ui_auth_types
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
@@ -113,10 +146,21 @@ class AuthHandler(BaseHandler):
 
         self._clock = self.hs.get_clock()
 
-        # Load the SSO redirect confirmation page HTML template
+        # Load the SSO HTML templates.
+
+        # The following template is shown to the user during a client login via SSO,
+        # after the SSO completes and before redirecting them back to their client.
+        # It notifies the user they are about to give access to their matrix account
+        # to the client.
         self._sso_redirect_confirm_template = load_jinja2_templates(
             hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
         )[0]
+        # The following template is shown during user interactive authentication
+        # in the fallback auth scenario. It notifies the user that they are
+        # authenticating for an operation to occur on their account.
+        self._sso_auth_confirm_template = load_jinja2_templates(
+            hs.config.sso_redirect_confirm_template_dir, ["sso_auth_confirm.html"],
+        )[0]
 
         self._server_name = hs.config.server_name
 
@@ -125,7 +169,12 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def validate_user_via_ui_auth(
-        self, requester: Requester, request_body: Dict[str, Any], clientip: str
+        self,
+        requester: Requester,
+        request: SynapseRequest,
+        request_body: Dict[str, Any],
+        clientip: str,
+        description: str,
     ):
         """
         Checks that the user is who they claim to be, via a UI auth.
@@ -137,10 +186,15 @@ class AuthHandler(BaseHandler):
         Args:
             requester: The user, as given by the access token
 
+            request: The request sent by the client.
+
             request_body: The body of the request sent by the client
 
             clientip: The IP address of the client.
 
+            description: A human readable string to be displayed to the user that
+                         describes the operation happening on their account.
+
         Returns:
             defer.Deferred[dict]: the parameters for this request (which may
                 have been given only in a previous call).
@@ -169,10 +223,12 @@ class AuthHandler(BaseHandler):
         )
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_login_types]
+        flows = [[login_type] for login_type in self._supported_ui_auth_types]
 
         try:
-            result, params, _ = yield self.check_auth(flows, request_body, clientip)
+            result, params, _ = yield self.check_auth(
+                flows, request, request_body, clientip, description
+            )
         except LoginError:
             # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
             self._failed_uia_attempts_ratelimiter.can_do_action(
@@ -185,7 +241,7 @@ class AuthHandler(BaseHandler):
             raise
 
         # find the completed login type
-        for login_type in self._supported_login_types:
+        for login_type in self._supported_ui_auth_types:
             if login_type not in result:
                 continue
 
@@ -211,7 +267,12 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def check_auth(
-        self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+        self,
+        flows: List[List[str]],
+        request: SynapseRequest,
+        clientdict: Dict[str, Any],
+        clientip: str,
+        description: str,
     ):
         """
         Takes a dictionary sent by the client in the login / registration
@@ -231,11 +292,16 @@ class AuthHandler(BaseHandler):
                    strings representing auth-types. At least one full
                    flow must be completed in order for auth to be successful.
 
+            request: The request sent by the client.
+
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
             clientip: The IP address of the client.
 
+            description: A human readable string to be displayed to the user that
+                         describes the operation happening on their account.
+
         Returns:
             defer.Deferred[dict, dict, str]: a deferred tuple of
                 (creds, params, session_id).
@@ -270,13 +336,33 @@ class AuthHandler(BaseHandler):
             # email auth link on there). It's probably too open to abuse
             # because it lets unauthenticated clients store arbitrary objects
             # on a homeserver.
-            # Revisit: Assumimg the REST APIs do sensible validation, the data
+            # Revisit: Assuming the REST APIs do sensible validation, the data
             # isn't arbintrary.
             session["clientdict"] = clientdict
             self._save_session(session)
         elif "clientdict" in session:
             clientdict = session["clientdict"]
 
+        # Ensure that the queried operation does not vary between stages of
+        # the UI authentication session. This is done by generating a stable
+        # comparator based on the URI, method, and body (minus the auth dict)
+        # and storing it during the initial query. Subsequent queries ensure
+        # that this comparator has not changed.
+        comparator = (request.uri, request.method, clientdict)
+        if "ui_auth" not in session:
+            session["ui_auth"] = comparator
+            self._save_session(session)
+        elif session["ui_auth"] != comparator:
+            raise SynapseError(
+                403,
+                "Requested operation has changed during the UI authentication session.",
+            )
+
+        # Add a human readable description to the session.
+        if "description" not in session:
+            session["description"] = description
+            self._save_session(session)
+
         if not authdict:
             raise InteractiveAuthIncompleteError(
                 self._auth_dict_for_flows(flows, session)
@@ -322,6 +408,7 @@ class AuthHandler(BaseHandler):
                     creds,
                     list(clientdict),
                 )
+
                 return creds, clientdict, session["id"]
 
         ret = self._auth_dict_for_flows(flows, session)
@@ -962,6 +1049,56 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+        """
+        Get the HTML for the SSO user interactive authentication confirmation page.
+
+        Args:
+            redirect_url: The URL to redirect to the SSO provider.
+            session_id: The user interactive authentication session ID.
+
+        Returns:
+            The HTML to render.
+        """
+        session = self._get_session_info(session_id)
+        # Get the human readable operation of what is occurring, falling back to
+        # a generic message if it isn't available for some reason.
+        description = session.get("description", "modify your account")
+        return self._sso_auth_confirm_template.render(
+            description=description, redirect_url=redirect_url,
+        )
+
+    def complete_sso_ui_auth(
+        self, registered_user_id: str, session_id: str, request: SynapseRequest,
+    ):
+        """Having figured out a mxid for this user, complete the HTTP request
+
+        Args:
+            registered_user_id: The registered user ID to complete SSO login for.
+            session_id: The ID of the user interactive authentication session
+                on which to record the completed SSO stage.
+            request: The request to complete.
+        """
+        # Mark the stage of the authentication as successful.
+        sess = self._get_session_info(session_id)
+        if "creds" not in sess:
+            sess["creds"] = {}
+        creds = sess["creds"]
+
+        # Save the user who authenticated with SSO, this will be used to ensure
+        # that the account being modified is also the person who logged in.
+        creds[LoginType.SSO] = registered_user_id
+        self._save_session(sess)
+
+        # Render the HTML and return.
+        html_bytes = SUCCESS_TEMPLATE.encode("utf8")
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+        request.write(html_bytes)
+        finish_request(request)
+
     def complete_sso_login(
         self,
         registered_user_id: str,
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
new file mode 100644
index 0000000000..f8dc274b78
--- /dev/null
+++ b/synapse/handlers/cas_handler.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import xml.etree.ElementTree as ET
+from typing import AnyStr, Dict, Optional, Tuple
+
+from six.moves import urllib
+
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.errors import Codes, LoginError
+from synapse.http.site import SynapseRequest
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+
+class CasHandler:
+    """
+    Utility class to handle the response from a CAS SSO service.
+
+    Args:
+        hs (synapse.server.HomeServer)
+    """
+
+    def __init__(self, hs):
+        self._hostname = hs.hostname
+        self._auth_handler = hs.get_auth_handler()
+        self._registration_handler = hs.get_registration_handler()
+
+        self._cas_server_url = hs.config.cas_server_url
+        self._cas_service_url = hs.config.cas_service_url
+        self._cas_displayname_attribute = hs.config.cas_displayname_attribute
+        self._cas_required_attributes = hs.config.cas_required_attributes
+
+        self._http_client = hs.get_proxied_http_client()
+
+    def _build_service_param(self, client_redirect_url: AnyStr) -> str:
+        return "%s%s?%s" % (
+            self._cas_service_url,
+            "/_matrix/client/r0/login/cas/ticket",
+            urllib.parse.urlencode({"redirectUrl": client_redirect_url}),
+        )
+
+    async def _handle_cas_response(
+        self, request: SynapseRequest, cas_response_body: str, client_redirect_url: str
+    ) -> None:
+        """
+        Retrieves the user and display name from the CAS response and continues with the authentication.
+
+        Args:
+            request: The original client request.
+            cas_response_body: The response from the CAS server.
+            client_redirect_url: The URL to redirect the client to when
+                everything is done.
+        """
+        user, attributes = self._parse_cas_response(cas_response_body)
+        displayname = attributes.pop(self._cas_displayname_attribute, None)
+
+        for required_attribute, required_value in self._cas_required_attributes.items():
+            # If required attribute was not in CAS Response - Unauthorized
+            if required_attribute not in attributes:
+                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+            # Also need to check value
+            if required_value is not None:
+                actual_value = attributes[required_attribute]
+                # If required attribute value does not match expected - Unauthorized
+                if required_value != actual_value:
+                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+        await self._on_successful_auth(user, request, client_redirect_url, displayname)
+
+    def _parse_cas_response(
+        self, cas_response_body: str
+    ) -> Tuple[str, Dict[str, Optional[str]]]:
+        """
+        Retrieve the user and other parameters from the CAS response.
+
+        Args:
+            cas_response_body: The response from the CAS query.
+
+        Returns:
+            A tuple of the user and a mapping of other attributes.
+        """
+        user = None
+        attributes = {}
+        try:
+            root = ET.fromstring(cas_response_body)
+            if not root.tag.endswith("serviceResponse"):
+                raise Exception("root of CAS response is not serviceResponse")
+            success = root[0].tag.endswith("authenticationSuccess")
+            for child in root[0]:
+                if child.tag.endswith("user"):
+                    user = child.text
+                if child.tag.endswith("attributes"):
+                    for attribute in child:
+                        # ElementTree library expands the namespace in
+                        # attribute tags to the full URL of the namespace.
+                        # We don't care about namespace here and it will always
+                        # be encased in curly braces, so we remove them.
+                        tag = attribute.tag
+                        if "}" in tag:
+                            tag = tag.split("}")[1]
+                        attributes[tag] = attribute.text
+            if user is None:
+                raise Exception("CAS response does not contain user")
+        except Exception:
+            logger.exception("Error parsing CAS response")
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+        if not success:
+            raise LoginError(
+                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+            )
+        return user, attributes
+
+    async def _on_successful_auth(
+        self,
+        username: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+        user_display_name: Optional[str] = None,
+    ) -> None:
+        """Called once the user has successfully authenticated with the SSO.
+
+        Registers the user if necessary, and then returns a redirect (with
+        a login token) to the client.
+
+        Args:
+            username: the remote user id. We'll map this onto
+                something sane for a MXID localpart.
+
+            request: the incoming request from the browser. We'll
+                respond to it with a redirect.
+
+            client_redirect_url: the redirect_url the client gave us when
+                it first started the process.
+
+            user_display_name: if set, and we have to register a new user,
+                we will set their displayname to this.
+        """
+        localpart = map_username_to_mxid_localpart(username)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart, default_display_name=user_display_name
+            )
+
+        self._auth_handler.complete_sso_login(
+            registered_user_id, request, client_redirect_url
+        )
+
+    def handle_redirect_request(self, client_redirect_url: bytes) -> bytes:
+        """
+        Generates a URL to the CAS server where the client should be redirected.
+
+        Args:
+            client_redirect_url: The final URL the client should go to after the
+                user has negotiated SSO.
+
+        Returns:
+            The URL to redirect to.
+        """
+        args = urllib.parse.urlencode(
+            {"service": self._build_service_param(client_redirect_url)}
+        )
+
+        return ("%s/login?%s" % (self._cas_server_url, args)).encode("ascii")
+
+    async def handle_ticket_request(
+        self, request: SynapseRequest, client_redirect_url: str, ticket: str
+    ) -> None:
+        """
+        Validates a CAS ticket sent by the client for login/registration.
+
+        On a successful request, writes a redirect to the request.
+        """
+        uri = self._cas_server_url + "/proxyValidate"
+        args = {
+            "ticket": ticket,
+            "service": self._build_service_param(client_redirect_url),
+        }
+        try:
+            body = await self._http_client.get_raw(uri, args)
+        except PartialDownloadError as pde:
+            # Twisted raises this error if the connection is closed,
+            # even if that's being used old-http style to signal end-of-data
+            body = pde.response
+
+        await self._handle_cas_response(request, body, client_redirect_url)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a514c30714..993499f446 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
         users_who_share_room = yield self.store.get_users_who_share_room_with_user(
             user_id
         )
+
+        tracked_users = set(users_who_share_room)
+
+        # Always tell the user about their own devices
+        tracked_users.add(user_id)
+
         changed = yield self.store.get_users_whose_devices_changed(
-            from_token.device_list_key, users_who_share_room
+            from_token.device_list_key, tracked_users
         )
 
         # Then work out if any users have since joined
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
         room_ids = yield self.store.get_rooms_for_user(user_id)
 
-        yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
+        # specify the user ID too since the user should always get their own device list
+        # updates, even if they aren't in any rooms.
+        yield self.notifier.on_new_event(
+            "device_list_key", position, users=[user_id], rooms=room_ids
+        )
 
         if hosts:
             logger.info(
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1d842c369b..53e5f585d9 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -127,7 +127,11 @@ class DirectoryHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
         else:
-            if self.require_membership and check_membership:
+            # Server admins are not subject to the same constraints as normal
+            # users when creating an alias (e.g. being in the room).
+            is_admin = yield self.auth.is_server_admin(requester.user)
+
+            if (self.require_membership and check_membership) and not is_admin:
                 rooms_for_user = yield self.store.get_rooms_for_user(user_id)
                 if room_id not in rooms_for_user:
                     raise AuthError(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38ab6a8fc3..c7aa7acf3b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
 from synapse.visibility import filter_events_for_server
 
-from ._base import BaseHandler
-
 logger = logging.getLogger(__name__)
 
 
@@ -93,27 +93,6 @@ class _NewEventInfo:
     auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
-def shortstr(iterable, maxitems=5):
-    """If iterable has maxitems or fewer, return the stringification of a list
-    containing those items.
-
-    Otherwise, return the stringification of a a list with the first maxitems items,
-    followed by "...".
-
-    Args:
-        iterable (Iterable): iterable to truncate
-        maxitems (int): number of items to return before truncating
-
-    Returns:
-        unicode
-    """
-
-    items = list(itertools.islice(iterable, maxitems + 1))
-    if len(items) <= maxitems:
-        return str(items)
-    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
-
-
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index b743fc2dcc..522271eed1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -852,6 +852,38 @@ class EventCreationHandler(object):
                 )
 
     @defer.inlineCallbacks
+    def _validate_canonical_alias(
+        self, directory_handler, room_alias_str, expected_room_id
+    ):
+        """
+        Ensure that the given room alias points to the expected room ID.
+
+        Args:
+            directory_handler: The directory handler object.
+            room_alias_str: The room alias to check.
+            expected_room_id: The room ID that the alias should point to.
+        """
+        room_alias = RoomAlias.from_string(room_alias_str)
+        try:
+            mapping = yield directory_handler.get_association(room_alias)
+        except SynapseError as e:
+            # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
+            if e.errcode == Codes.NOT_FOUND:
+                raise SynapseError(
+                    400,
+                    "Room alias %s does not point to the room" % (room_alias_str,),
+                    Codes.BAD_ALIAS,
+                )
+            raise
+
+        if mapping["room_id"] != expected_room_id:
+            raise SynapseError(
+                400,
+                "Room alias %s does not point to the room" % (room_alias_str,),
+                Codes.BAD_ALIAS,
+            )
+
+    @defer.inlineCallbacks
     def persist_and_notify_client_event(
         self, requester, event, context, ratelimit=True, extra_users=[]
     ):
@@ -905,15 +937,9 @@ class EventCreationHandler(object):
             room_alias_str = event.content.get("alias", None)
             directory_handler = self.hs.get_handlers().directory_handler
             if room_alias_str and room_alias_str != original_alias:
-                room_alias = RoomAlias.from_string(room_alias_str)
-                mapping = yield directory_handler.get_association(room_alias)
-
-                if mapping["room_id"] != event.room_id:
-                    raise SynapseError(
-                        400,
-                        "Room alias %s does not point to the room" % (room_alias_str,),
-                        Codes.BAD_ALIAS,
-                    )
+                yield self._validate_canonical_alias(
+                    directory_handler, room_alias_str, event.room_id
+                )
 
             # Check that alt_aliases is the proper form.
             alt_aliases = event.content.get("alt_aliases", [])
@@ -931,16 +957,9 @@ class EventCreationHandler(object):
             new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
             if new_alt_aliases:
                 for alias_str in new_alt_aliases:
-                    room_alias = RoomAlias.from_string(alias_str)
-                    mapping = yield directory_handler.get_association(room_alias)
-
-                    if mapping["room_id"] != event.room_id:
-                        raise SynapseError(
-                            400,
-                            "Room alias %s does not point to the room"
-                            % (room_alias_str,),
-                            Codes.BAD_ALIAS,
-                        )
+                    yield self._validate_canonical_alias(
+                        directory_handler, alias_str, event.room_id
+                    )
 
         federation_handler = self.hs.get_handlers().federation_handler
 
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+    def __init__(self, hs):
+        self.policy = hs.config.password_policy
+        self.enabled = hs.config.password_policy_enabled
+
+        # Regexps for the spec'd policy parameters.
+        self.regexp_digit = re.compile("[0-9]")
+        self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+        self.regexp_uppercase = re.compile("[A-Z]")
+        self.regexp_lowercase = re.compile("[a-z]")
+
+    def validate_password(self, password):
+        """Checks whether a given password complies with the server's policy.
+
+        Args:
+            password (str): The password to check against the server's policy.
+
+        Raises:
+            PasswordRefusedError: The password doesn't comply with the server's policy.
+        """
+
+        if not self.enabled:
+            return
+
+        minimum_accepted_length = self.policy.get("minimum_length", 0)
+        if len(password) < minimum_accepted_length:
+            raise PasswordRefusedError(
+                msg=(
+                    "The password must be at least %d characters long"
+                    % minimum_accepted_length
+                ),
+                errcode=Codes.PASSWORD_TOO_SHORT,
+            )
+
+        if (
+            self.policy.get("require_digit", False)
+            and self.regexp_digit.search(password) is None
+        ):
+            raise PasswordRefusedError(
+                msg="The password must include at least one digit",
+                errcode=Codes.PASSWORD_NO_DIGIT,
+            )
+
+        if (
+            self.policy.get("require_symbol", False)
+            and self.regexp_symbol.search(password) is None
+        ):
+            raise PasswordRefusedError(
+                msg="The password must include at least one symbol",
+                errcode=Codes.PASSWORD_NO_SYMBOL,
+            )
+
+        if (
+            self.policy.get("require_uppercase", False)
+            and self.regexp_uppercase.search(password) is None
+        ):
+            raise PasswordRefusedError(
+                msg="The password must include at least one uppercase letter",
+                errcode=Codes.PASSWORD_NO_UPPERCASE,
+            )
+
+        if (
+            self.policy.get("require_lowercase", False)
+            and self.regexp_lowercase.search(password) is None
+        ):
+            raise PasswordRefusedError(
+                msg="The password must include at least one lowercase letter",
+                errcode=Codes.PASSWORD_NO_LOWERCASE,
+            )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5526015ddb..6912165622 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -747,7 +747,7 @@ class PresenceHandler(object):
 
         return False
 
-    async def get_all_presence_updates(self, last_id, current_id):
+    async def get_all_presence_updates(self, last_id, current_id, limit):
         """
         Gets a list of presence update rows from between the given stream ids.
         Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
         """
         # TODO(markjh): replicate the unpersisted changes.
         # This could use the in-memory stores for recent changes.
-        rows = await self.store.get_all_presence_updates(last_id, current_id)
+        rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
         return rows
 
     def notify_new_event(self):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 50ce0c585b..6aa1c0f5e0 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -157,6 +157,15 @@ class BaseProfileHandler(BaseHandler):
         if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's displayname")
 
+        if not by_admin and not self.hs.config.enable_set_displayname:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            if profile.display_name:
+                raise SynapseError(
+                    400,
+                    "Changing display name is disabled on this server",
+                    Codes.FORBIDDEN,
+                )
+
         if len(new_displayname) > MAX_DISPLAYNAME_LEN:
             raise SynapseError(
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -218,6 +227,13 @@ class BaseProfileHandler(BaseHandler):
         if not by_admin and target_user != requester.user:
             raise AuthError(400, "Cannot set another user's avatar_url")
 
+        if not by_admin and not self.hs.config.enable_set_avatar_url:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            if profile.avatar_url:
+                raise SynapseError(
+                    400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
+                )
+
         if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
             raise SynapseError(
                 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4260426369..c3ee8db4f0 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -519,6 +519,9 @@ class RoomMemberHandler(object):
             yield self.store.set_room_is_public(old_room_id, False)
             yield self.store.set_room_is_public(room_id, True)
 
+        # Transfer alias mappings in the room directory
+        yield self.store.update_aliases_for_room(old_room_id, room_id)
+
         # Check if any groups we own contain the predecessor room
         local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
         for group_id in local_group_ids:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 72c109981b..4741c82f61 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Tuple
+from typing import Optional, Tuple
 
 import attr
 import saml2
@@ -26,6 +26,7 @@ from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
 from synapse.module_api import ModuleApi
+from synapse.module_api.errors import RedirectException
 from synapse.types import (
     UserID,
     map_username_to_mxid_localpart,
@@ -43,11 +44,15 @@ class Saml2SessionData:
 
     # time the session was created, in milliseconds
     creation_time = attr.ib()
+    # The user interactive authentication session ID associated with this SAML
+    # session (or None if this SAML session is for an initial login).
+    ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
 class SamlHandler:
     def __init__(self, hs):
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+        self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -76,12 +81,14 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url):
+    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
         """Handle an incoming request to /login/sso/redirect
 
         Args:
             client_redirect_url (bytes): the URL that we should redirect the
                 client to when everything is done
+            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
             bytes: URL to redirect to
@@ -91,7 +98,9 @@ class SamlHandler:
         )
 
         now = self._clock.time_msec()
-        self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+        self._outstanding_requests_dict[reqid] = Saml2SessionData(
+            creation_time=now, ui_auth_session_id=ui_auth_session_id,
+        )
 
         for key, value in info["headers"]:
             if key == "Location":
@@ -118,7 +127,12 @@ class SamlHandler:
         self.expire_sessions()
 
         try:
-            user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+            user_id, current_session = await self._map_saml_response_to_user(
+                resp_bytes, relay_state
+            )
+        except RedirectException:
+            # Raise the exception as per the wishes of the SAML module response
+            raise
         except Exception as e:
             # If decoding the response or mapping it to a user failed, then log the
             # error and tell the user that something went wrong.
@@ -133,9 +147,28 @@ class SamlHandler:
             finish_request(request)
             return
 
-        self._auth_handler.complete_sso_login(user_id, request, relay_state)
+        # Complete the interactive auth session or the login.
+        if current_session and current_session.ui_auth_session_id:
+            self._auth_handler.complete_sso_ui_auth(
+                user_id, current_session.ui_auth_session_id, request
+            )
+
+        else:
+            self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+    async def _map_saml_response_to_user(
+        self, resp_bytes: str, client_redirect_url: str
+    ) -> Tuple[str, Optional[Saml2SessionData]]:
+        """
+        Given a SAML response, retrieve the cached session and user for it.
 
-    async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+        Args:
+            resp_bytes: The SAML response.
+            client_redirect_url: The redirect URL passed in by the client.
+
+        Returns:
+             Tuple of the user ID and SAML session associated with this response.
+        """
         try:
             saml2_auth = self._saml_client.parse_authn_request_response(
                 resp_bytes,
@@ -163,7 +196,9 @@ class SamlHandler:
 
         logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
 
-        self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+        current_session = self._outstanding_requests_dict.pop(
+            saml2_auth.in_response_to, None
+        )
 
         remote_user_id = self._user_mapping_provider.get_remote_user_id(
             saml2_auth, client_redirect_url
@@ -184,7 +219,7 @@ class SamlHandler:
             )
             if registered_user_id is not None:
                 logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
+                return registered_user_id, current_session
 
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
@@ -209,7 +244,7 @@ class SamlHandler:
                     await self._datastore.record_user_external_id(
                         self._auth_provider_id, remote_user_id, registered_user_id
                     )
-                    return registered_user_id
+                    return registered_user_id, current_session
 
             # Map saml response to user attributes using the configured mapping provider
             for i in range(1000):
@@ -256,7 +291,7 @@ class SamlHandler:
             await self._datastore.record_user_external_id(
                 self._auth_provider_id, remote_user_id, registered_user_id
             )
-            return registered_user_id
+            return registered_user_id, current_session
 
     def expire_sessions(self):
         expire_before = self._clock.time_msec() - self._saml2_session_lifetime
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 12657ca698..7d1263caf2 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -32,6 +32,7 @@ class SetPasswordHandler(BaseHandler):
         super(SetPasswordHandler, self).__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
+        self._password_policy_handler = hs.get_password_policy_handler()
 
     @defer.inlineCallbacks
     def set_password(
@@ -44,6 +45,7 @@ class SetPasswordHandler(BaseHandler):
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
+        self._password_policy_handler.validate_password(new_password)
         password_hash = yield self._auth_handler.hash(new_password)
 
         try:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 669dbc8a48..1f1cde2feb 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -26,7 +26,7 @@ from prometheus_client import Counter
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.events import EventBase
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
         else:
             sync_type = "incremental_sync"
 
-        context = LoggingContext.current_context()
+        context = current_context()
         if context:
             context.tag = sync_type
 
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
                 user_id
             )
 
+            tracked_users = set(users_who_share_room)
+
+            # Always tell the user about their own devices
+            tracked_users.add(user_id)
+
             # Step 1a, check for changes in devices of users we share a room with
             users_that_have_changed = await self.store.get_users_whose_devices_changed(
-                since_token.device_list_key, users_who_share_room
+                since_token.device_list_key, tracked_users
             )
 
             # Step 1b, check for newly joined rooms
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 391bceb0c4..c7bc14c623 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,6 +15,7 @@
 
 import logging
 from collections import namedtuple
+from typing import List
 
 from twisted.internet import defer
 
@@ -257,7 +258,13 @@ class TypingHandler(object):
             "typing_key", self._latest_room_serial, rooms=[member.room_id]
         )
 
-    async def get_all_typing_updates(self, last_id, current_id):
+    async def get_all_typing_updates(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[dict]:
+        """Get up to `limit` typing updates between the given tokens, earliest
+        updates first.
+        """
+
         if last_id == current_id:
             return []
 
@@ -275,7 +282,7 @@ class TypingHandler(object):
                 typing = self._room_typing[room_id]
                 rows.append((serial, room_id, list(typing)))
         rows.sort()
-        return rows
+        return rows[:limit]
 
     def get_current_token(self):
         return self._latest_room_serial
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 58f9cc61c8..b58ae3d9db 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -19,7 +19,7 @@ import threading
 
 from prometheus_client.core import Counter, Histogram
 
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 from synapse.metrics import LaterGauge
 
 logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ LaterGauge(
 class RequestMetrics(object):
     def start(self, time_sec, name, method):
         self.start = time_sec
-        self.start_context = LoggingContext.current_context()
+        self.start_context = current_context()
         self.name = name
         self.method = method
 
@@ -163,7 +163,7 @@ class RequestMetrics(object):
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)
 
-        context = LoggingContext.current_context()
+        context = current_context()
 
         tag = ""
         if context:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e092193c9c..32feb0d968 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -193,6 +193,12 @@ class SynapseRequest(Request):
         self.finish_time = time.time()
         Request.connectionLost(self, reason)
 
+        if self.logcontext is None:
+            logger.info(
+                "Connection from %s lost before request headers were read", self.client
+            )
+            return
+
         # we only get here if the connection to the client drops before we send
         # the response.
         #
@@ -236,13 +242,6 @@ class SynapseRequest(Request):
     def _finished_processing(self):
         """Log the completion of this request and update the metrics
         """
-
-        if self.logcontext is None:
-            # this can happen if the connection closed before we read the
-            # headers (so render was never called). In that case we'll already
-            # have logged a warning, so just bail out.
-            return
-
         usage = self.logcontext.get_resource_usage()
 
         if self._processing_finished_time is None:
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index ffa7b20ca8..7372450b45 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -42,7 +42,7 @@ from synapse.logging._terse_json import (
     TerseJSONToConsoleLogObserver,
     TerseJSONToTCPLogObserver,
 )
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
 
 
 def stdlib_log_level_to_twisted(level: str) -> LogLevel:
@@ -86,7 +86,7 @@ class LogContextObserver(object):
             ].startswith("Timing out client"):
                 return
 
-        context = LoggingContext.current_context()
+        context = current_context()
 
         # Copy the context information to the log event.
         if context is not None:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 860b99a4c6..3254d6a8df 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -175,7 +175,54 @@ class ContextResourceUsage(object):
         return res
 
 
-LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
+LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
+
+
+class _Sentinel(object):
+    """Sentinel to represent the root context"""
+
+    __slots__ = ["previous_context", "finished", "request", "scope", "tag"]
+
+    def __init__(self) -> None:
+        # Minimal set for compatibility with LoggingContext
+        self.previous_context = None
+        self.finished = False
+        self.request = None
+        self.scope = None
+        self.tag = None
+
+    def __str__(self):
+        return "sentinel"
+
+    def copy_to(self, record):
+        pass
+
+    def copy_to_twisted_log_entry(self, record):
+        record["request"] = None
+        record["scope"] = None
+
+    def start(self):
+        pass
+
+    def stop(self):
+        pass
+
+    def add_database_transaction(self, duration_sec):
+        pass
+
+    def add_database_scheduled(self, sched_sec):
+        pass
+
+    def record_event_fetch(self, event_count):
+        pass
+
+    def __nonzero__(self):
+        return False
+
+    __bool__ = __nonzero__  # python3
+
+
+SENTINEL_CONTEXT = _Sentinel()
 
 
 class LoggingContext(object):
@@ -199,76 +246,33 @@ class LoggingContext(object):
         "_resource_usage",
         "usage_start",
         "main_thread",
-        "alive",
+        "finished",
         "request",
         "tag",
         "scope",
     ]
 
-    thread_local = threading.local()
-
-    class Sentinel(object):
-        """Sentinel to represent the root context"""
-
-        __slots__ = ["previous_context", "alive", "request", "scope", "tag"]
-
-        def __init__(self) -> None:
-            # Minimal set for compatibility with LoggingContext
-            self.previous_context = None
-            self.alive = None
-            self.request = None
-            self.scope = None
-            self.tag = None
-
-        def __str__(self):
-            return "sentinel"
-
-        def copy_to(self, record):
-            pass
-
-        def copy_to_twisted_log_entry(self, record):
-            record["request"] = None
-            record["scope"] = None
-
-        def start(self):
-            pass
-
-        def stop(self):
-            pass
-
-        def add_database_transaction(self, duration_sec):
-            pass
-
-        def add_database_scheduled(self, sched_sec):
-            pass
-
-        def record_event_fetch(self, event_count):
-            pass
-
-        def __nonzero__(self):
-            return False
-
-        __bool__ = __nonzero__  # python3
-
-    sentinel = Sentinel()
-
     def __init__(self, name=None, parent_context=None, request=None) -> None:
-        self.previous_context = LoggingContext.current_context()
+        self.previous_context = current_context()
         self.name = name
 
         # track the resources used by this context so far
         self._resource_usage = ContextResourceUsage()
 
-        # If alive has the thread resource usage when the logcontext last
-        # became active.
+        # The thread resource usage when the logcontext became active. None
+        # if the context is not currently active.
         self.usage_start = None
 
         self.main_thread = get_thread_id()
         self.request = None
         self.tag = ""
-        self.alive = True
         self.scope = None  # type: Optional[_LogContextScope]
 
+        # keep track of whether we have hit the __exit__ block for this context
+        # (suggesting that the thing that created the context thinks it should
+        # be finished, and that re-activating it would suggest an error).
+        self.finished = False
+
         self.parent_context = parent_context
 
         if self.parent_context is not None:
@@ -283,44 +287,15 @@ class LoggingContext(object):
             return str(self.request)
         return "%s@%x" % (self.name, id(self))
 
-    @classmethod
-    def current_context(cls) -> LoggingContextOrSentinel:
-        """Get the current logging context from thread local storage
-
-        Returns:
-            LoggingContext: the current logging context
-        """
-        return getattr(cls.thread_local, "current_context", cls.sentinel)
-
-    @classmethod
-    def set_current_context(
-        cls, context: LoggingContextOrSentinel
-    ) -> LoggingContextOrSentinel:
-        """Set the current logging context in thread local storage
-        Args:
-            context(LoggingContext): The context to activate.
-        Returns:
-            The context that was previously active
-        """
-        current = cls.current_context()
-
-        if current is not context:
-            current.stop()
-            cls.thread_local.current_context = context
-            context.start()
-        return current
-
     def __enter__(self) -> "LoggingContext":
         """Enters this logging context into thread local storage"""
-        old_context = self.set_current_context(self)
+        old_context = set_current_context(self)
         if self.previous_context != old_context:
             logger.warning(
                 "Expected previous context %r, found %r",
                 self.previous_context,
                 old_context,
             )
-        self.alive = True
-
         return self
 
     def __exit__(self, type, value, traceback) -> None:
@@ -329,24 +304,19 @@ class LoggingContext(object):
         Returns:
             None to avoid suppressing any exceptions that were thrown.
         """
-        current = self.set_current_context(self.previous_context)
+        current = set_current_context(self.previous_context)
         if current is not self:
-            if current is self.sentinel:
+            if current is SENTINEL_CONTEXT:
                 logger.warning("Expected logging context %s was lost", self)
             else:
                 logger.warning(
                     "Expected logging context %s but found %s", self, current
                 )
-        self.alive = False
 
-        # if we have a parent, pass our CPU usage stats on
-        if self.parent_context is not None and hasattr(
-            self.parent_context, "_resource_usage"
-        ):
-            self.parent_context._resource_usage += self._resource_usage
-
-            # reset them in case we get entered again
-            self._resource_usage.reset()
+        # the fact that we are here suggests that the caller thinks that everything
+        # is done and dusted for this logcontext, and further activity will not get
+        # recorded against the correct metrics.
+        self.finished = True
 
     def copy_to(self, record) -> None:
         """Copy logging fields from this context to a log record or
@@ -371,9 +341,14 @@ class LoggingContext(object):
             logger.warning("Started logcontext %s on different thread", self)
             return
 
+        if self.finished:
+            logger.warning("Re-starting finished log context %s", self)
+
         # If we haven't already started record the thread resource usage so
         # far
-        if not self.usage_start:
+        if self.usage_start:
+            logger.warning("Re-starting already-active log context %s", self)
+        else:
             self.usage_start = get_thread_resource_usage()
 
     def stop(self) -> None:
@@ -396,6 +371,15 @@ class LoggingContext(object):
 
         self.usage_start = None
 
+        # if we have a parent, pass our CPU usage stats on
+        if self.parent_context is not None and hasattr(
+            self.parent_context, "_resource_usage"
+        ):
+            self.parent_context._resource_usage += self._resource_usage
+
+            # reset them in case we get entered again
+            self._resource_usage.reset()
+
     def get_resource_usage(self) -> ContextResourceUsage:
         """Get resources used by this logcontext so far.
 
@@ -409,7 +393,7 @@ class LoggingContext(object):
         # If we are on the correct thread and we're currently running then we
         # can include resource usage so far.
         is_main_thread = get_thread_id() == self.main_thread
-        if self.alive and self.usage_start and is_main_thread:
+        if self.usage_start and is_main_thread:
             utime_delta, stime_delta = self._get_cputime()
             res.ru_utime += utime_delta
             res.ru_stime += stime_delta
@@ -492,7 +476,7 @@ class LoggingContextFilter(logging.Filter):
         Returns:
             True to include the record in the log output.
         """
-        context = LoggingContext.current_context()
+        context = current_context()
         for key, value in self.defaults.items():
             setattr(record, key, value)
 
@@ -512,27 +496,24 @@ class PreserveLoggingContext(object):
 
     __slots__ = ["current_context", "new_context", "has_parent"]
 
-    def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
-        if new_context is None:
-            self.new_context = LoggingContext.sentinel  # type: LoggingContextOrSentinel
-        else:
-            self.new_context = new_context
+    def __init__(
+        self, new_context: LoggingContextOrSentinel = SENTINEL_CONTEXT
+    ) -> None:
+        self.new_context = new_context
 
     def __enter__(self) -> None:
         """Captures the current logging context"""
-        self.current_context = LoggingContext.set_current_context(self.new_context)
+        self.current_context = set_current_context(self.new_context)
 
         if self.current_context:
             self.has_parent = self.current_context.previous_context is not None
-            if not self.current_context.alive:
-                logger.debug("Entering dead context: %s", self.current_context)
 
     def __exit__(self, type, value, traceback) -> None:
         """Restores the current logging context"""
-        context = LoggingContext.set_current_context(self.current_context)
+        context = set_current_context(self.current_context)
 
         if context != self.new_context:
-            if context is LoggingContext.sentinel:
+            if not context:
                 logger.warning("Expected logging context %s was lost", self.new_context)
             else:
                 logger.warning(
@@ -541,9 +522,35 @@ class PreserveLoggingContext(object):
                     context,
                 )
 
-        if self.current_context is not LoggingContext.sentinel:
-            if not self.current_context.alive:
-                logger.debug("Restoring dead context: %s", self.current_context)
+
+_thread_local = threading.local()
+_thread_local.current_context = SENTINEL_CONTEXT
+
+
+def current_context() -> LoggingContextOrSentinel:
+    """Get the current logging context from thread local storage"""
+    return getattr(_thread_local, "current_context", SENTINEL_CONTEXT)
+
+
+def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSentinel:
+    """Set the current logging context in thread local storage
+    Args:
+        context(LoggingContext): The context to activate.
+    Returns:
+        The context that was previously active
+    """
+    # everything blows up if we allow current_context to be set to None, so sanity-check
+    # that now.
+    if context is None:
+        raise TypeError("'context' argument may not be None")
+
+    current = current_context()
+
+    if current is not context:
+        current.stop()
+        _thread_local.current_context = context
+        context.start()
+    return current
 
 
 def nested_logging_context(
@@ -572,7 +579,7 @@ def nested_logging_context(
     if parent_context is not None:
         context = parent_context  # type: LoggingContextOrSentinel
     else:
-        context = LoggingContext.current_context()
+        context = current_context()
     return LoggingContext(
         parent_context=context, request=str(context.request) + "-" + suffix
     )
@@ -604,7 +611,7 @@ def run_in_background(f, *args, **kwargs):
     CRITICAL error about an unhandled error will be logged without much
     indication about where it came from.
     """
-    current = LoggingContext.current_context()
+    current = current_context()
     try:
         res = f(*args, **kwargs)
     except:  # noqa: E722
@@ -625,7 +632,7 @@ def run_in_background(f, *args, **kwargs):
 
     # The function may have reset the context before returning, so
     # we need to restore it now.
-    ctx = LoggingContext.set_current_context(current)
+    ctx = set_current_context(current)
 
     # The original context will be restored when the deferred
     # completes, but there is nothing waiting for it, so it will
@@ -674,7 +681,7 @@ def make_deferred_yieldable(deferred):
 
     # ok, we can't be sure that a yield won't block, so let's reset the
     # logcontext, and add a callback to the deferred to restore it.
-    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    prev_context = set_current_context(SENTINEL_CONTEXT)
     deferred.addBoth(_set_context_cb, prev_context)
     return deferred
 
@@ -684,7 +691,7 @@ ResultT = TypeVar("ResultT")
 
 def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     """A callback function which just sets the logging context"""
-    LoggingContext.set_current_context(context)
+    set_current_context(context)
     return result
 
 
@@ -752,7 +759,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = LoggingContext.current_context()
+    logcontext = current_context()
 
     def g():
         with LoggingContext(parent_context=logcontext):
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 4eed4f2338..dc3ab00cbb 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -19,7 +19,7 @@ from opentracing import Scope, ScopeManager
 
 import twisted
 
-from synapse.logging.context import LoggingContext, nested_logging_context
+from synapse.logging.context import current_context, nested_logging_context
 
 logger = logging.getLogger(__name__)
 
@@ -49,11 +49,8 @@ class LogContextScopeManager(ScopeManager):
             (Scope) : the Scope that is active, or None if not
             available.
         """
-        ctx = LoggingContext.current_context()
-        if ctx is LoggingContext.sentinel:
-            return None
-        else:
-            return ctx.scope
+        ctx = current_context()
+        return ctx.scope
 
     def activate(self, span, finish_on_close):
         """
@@ -70,9 +67,9 @@ class LogContextScopeManager(ScopeManager):
         """
 
         enter_logcontext = False
-        ctx = LoggingContext.current_context()
+        ctx = current_context()
 
-        if ctx is LoggingContext.sentinel:
+        if not ctx:
             # We don't want this scope to affect.
             logger.error("Tried to activate scope outside of loggingcontext")
             return Scope(None, span)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
     membership,
     register,
     send_event,
+    streams,
 )
 
 REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
         login.register_servlets(hs, self)
         register.register_servlets(hs, self)
         devices.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..ffd4c61993
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        # We pull the streams from the replication streamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token, limit):
+        return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+        limit = parse_integer(request, "limit", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            from_token, upto_token, limit
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
 
 import six
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
@@ -35,7 +37,7 @@ def __func__(inp):
         return inp.__func__
 
 
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
             if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..e86d9805f1 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.client_name = client_name
         self.handler = handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs, self.client_name, self.server_name, self._clock, self.handler,
         )
 
     def clientConnectionLost(self, connector, reason):
@@ -188,10 +189,12 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         """
         self.send_command(FederationAckCommand(token))
 
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+    def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """Poke the master that a user has started/stopped syncing.
         """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
 
     def send_remove_pusher(self, app_id, push_key, user_id):
         """Poke the master to remove a pusher for a user
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..e4eec643f7 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -136,8 +136,8 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.
 
-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
     """
 
     NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
 
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
@@ -225,30 +207,32 @@ class UserSyncCommand(Command):
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
     Where <state> is either "start" or "stop"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -256,6 +240,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -416,6 +424,7 @@ _COMMANDS = (
     InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    ClearUserSyncsCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -438,6 +447,7 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..dae246825f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
 
-from six import iteritems, iterkeys
+from six import iteritems
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.streamer = streamer
 
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
@@ -432,25 +423,17 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     async def on_USER_SYNC(self, cmd):
         await self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
-    async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
+    async def on_CLEAR_USER_SYNC(self, cmd):
+        await self.streamer.on_clear_user_syncs(cmd.instance_id)
 
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
+    async def on_REPLICATE(self, cmd):
+        # Subscribe to all streams we're publishing to.
+        for stream_name in self.streamer.streams_by_name:
+            current_token = self.streamer.get_stream_token(stream_name)
+            self.send_command(PositionCommand(stream_name, current_token))
 
     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -474,87 +457,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.
 
         We need to check if the client is interested in the stream or not
         """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+        self.send_command(RdataCommand(stream_name, token, data))
 
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
@@ -638,6 +546,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
@@ -645,41 +554,42 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     ):
         BaseReplicationStreamProtocol.__init__(self, clock)
 
+        self.instance_id = hs.get_instance_id()
+
         self.client_name = client_name
         self.server_name = server_name
         self.handler = handler
 
+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
+        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
+        self.pending_batches = {}  # type: Dict[str, List[Any]]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
+        self.replicate()
 
         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)
         currently_syncing = self.handler.get_currently_syncing_users()
         now = self.clock.time_msec()
         for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
+            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
 
         # We've now finished connecting to so inform the client handler
         self.handler.update_connection(self)
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +607,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             )
             raise
 
-        if cmd.token is None:
+        if cmd.token is None or stream_name in self.streams_connecting:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
             self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +617,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             rows.append(row)
             await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+        if current_token is None:
+            logger.warning(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = True
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+
+            # Check if the connection was closed underneath us, if so we bail
+            # rather than risk having concurrent catch ups going on.
+            if self.state == ConnectionStates.CLOSED:
+                return
+
+            if updates:
+                await self.handler.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to the position sent to us, so notify the handler.
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        # Check if the connection was closed underneath us, if so we bail
+        # rather than risk having concurrent catch ups going on.
+        if self.state == ConnectionStates.CLOSED:
+            return
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
 
     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
@@ -722,22 +673,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)
 
-    def replicate(self, stream_name, token):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand(stream_name, token))
+        self.send_command(ReplicateCommand())
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..30021ee309 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import Any, List
+from typing import Any, Dict, List
 
 from six import itervalues
 
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.metrics import Measure, measure_func
 
 from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
 from .streams.federation import FederationStream
 
 stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.streamer = hs.get_replication_streamer()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
@@ -99,22 +99,6 @@ class ReplicationStreamer(object):
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
-        LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream",
-            "",
-            ["stream_name"],
-            lambda: {
-                (stream_name,): len(
-                    [
-                        conn
-                        for conn in self.connections
-                        if stream_name in conn.replication_streams
-                    ]
-                )
-                for stream_name in self.streams_by_name
-            },
-        )
-
         self.federation_sender = None
         if not hs.config.send_federation:
             self.federation_sender = hs.get_federation_sender()
@@ -133,6 +117,11 @@ class ReplicationStreamer(object):
         for conn in self.connections:
             conn.send_error("server shutting down")
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -166,11 +155,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,7 +164,7 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token():
                             continue
 
                         if self._replication_torture_level:
@@ -192,10 +176,11 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
@@ -231,8 +216,7 @@ class ReplicationStreamer(object):
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
+    def get_stream_token(self, stream_name):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -240,7 +224,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return await stream.get_updates_since(token)
+        return stream.current_token()
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
@@ -251,14 +235,19 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+    async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
         await self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
+            instance_id, user_id, is_syncing, last_sync_ms
         )
 
+    async def on_clear_user_syncs(self, instance_id):
+        """A replication client wants us to drop all their UserSync data.
+        """
+        await self.presence_handler.update_external_syncs_clear(instance_id)
+
     @measure_func("repl.on_remove_pusher")
     async def on_remove_pusher(self, app_id, push_key, user_id):
         """A client has asked us to remove a pusher
@@ -321,14 +310,6 @@ class ReplicationStreamer(object):
         except ValueError:
             pass
 
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        run_as_background_process(
-            "update_external_syncs_clear",
-            self.presence_handler.update_external_syncs_clear,
-            connection.conn_id,
-        )
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,26 +25,66 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from . import _base, events, federation
+from typing import Dict, Type
+
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
-        _base.UserSignatureStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
-}
+}  # type: Dict[str, Type[Stream]]
+
+
+__all__ = [
+    "STREAMS_MAP",
+    "Stream",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,114 +14,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
 
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.types import JsonDict
+
 logger = logging.getLogger(__name__)
 
 
 MAX_EVENTS_BEHIND = 500000
 
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-
-
-@attr.s
-class CachesStreamRow:
-    """Stream to inform workers they should invalidate their cache.
-
-    Attributes:
-        cache_func: Name of the cached function.
-        keys: The entry in the cache to invalidate. If None then will
-            invalidate all.
-        invalidation_ts: Timestamp of when the invalidation took place.
-    """
 
-    cache_func = attr.ib(type=str)
-    keys = attr.ib(type=Optional[List[Any]])
-    invalidation_ts = attr.ib(type=int)
-
-
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
+StreamRow = Tuple[Token, tuple]
 
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
     def parse_row(cls, row):
@@ -139,80 +65,56 @@ class Stream(object):
         return cls.ROW_TYPE(*row)
 
     def __init__(self, hs):
+
         # The token from which we last asked for updates
         self.last_token = self.current_token()
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
-
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
-        """
-        self.upto_token = self.current_token()
-
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
-    async def get_updates(self):
+    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, from_token: Token, upto_token: Token, limit: int = 100
+    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        if from_token in ("NOW", "now"):
-            return [], self.upto_token
-
-        current_token = self.upto_token
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
-
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        if from_token == upto_token:
+            return [], upto_token, False
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
-
-        updates = [(row[0], row[1:]) for row in rows]
-
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
-
-        return updates, current_token
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
+        )
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -223,9 +125,8 @@ class Stream(object):
         """
         raise NotImplementedError()
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    def update_function(self, from_token, current_token, limit):
+        """Get updates between from_token and current_token.
 
         Returns:
             Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -235,52 +136,144 @@ class Stream(object):
         raise NotImplementedError()
 
 
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
+def make_http_update_function(
+    hs, stream_name: str
+) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
+
+    client = ReplicationGetStreamUpdates.make_client(hs)
+
+    async def update_function(
+        from_token: int, upto_token: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        return await client(
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+            limit=limit,
+        )
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
-        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
         self.current_token = typing_handler.get_current_token  # type: ignore
-        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
@@ -288,7 +281,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -297,6 +290,8 @@ class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
@@ -310,13 +305,24 @@ class PushRulesStream(Stream):
 
     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
@@ -324,7 +330,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -334,6 +340,21 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
@@ -341,7 +362,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -350,6 +371,16 @@ class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
@@ -357,24 +388,28 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -383,6 +418,8 @@ class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
@@ -390,7 +427,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -399,6 +436,10 @@ class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
@@ -406,7 +447,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -415,6 +456,10 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStreamRow", ("user_id", "room_id", "data_type")  # str  # str  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
@@ -422,10 +467,11 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
 
         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -440,6 +486,11 @@ class AccountDataStream(Stream):
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
@@ -447,7 +498,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -456,14 +507,15 @@ class UserSignatureStream(Stream):
     """A user has signed their own device with their user-signing key
     """
 
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id",))  # str
+
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         self.current_token = self._store.get_current_events_token  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 615f3dc9ac..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,15 +15,9 @@
 # limitations under the License.
 from collections import namedtuple
 
-from ._base import Stream
+from twisted.internet import defer
 
-FederationStreamRow = namedtuple(
-    "FederationStreamRow",
-    (
-        "type",  # str, the type of data as defined in the BaseFederationRows
-        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-    ),
-)
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -31,13 +25,28 @@ class FederationStream(Stream):
     sending disabled.
     """
 
+    FederationStreamRow = namedtuple(
+        "FederationStreamRow",
+        (
+            "type",  # str, the type of data as defined in the BaseFederationRows
+            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+        ),
+    )
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
+    _QUERY_MASTER = True
 
     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
-
-        self.current_token = federation_sender.get_current_token  # type: ignore
-        self.update_function = federation_sender.get_replication_rows  # type: ignore
+        # Not all synapse instances will have a federation sender instance,
+        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+        # so we stub the stream out (no rows, never limited) when that is the case.
+        if hs.config.worker_app is None or hs.should_send_federation():
+            federation_sender = hs.get_federation_sender()
+            self.current_token = federation_sender.get_current_token  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
+        else:
+            self.current_token = lambda: 0  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore
 
         super(FederationStream, self).__init__(hs)
diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
new file mode 100644
index 0000000000..0d9de9d465
--- /dev/null
+++ b/synapse/res/templates/sso_auth_confirm.html
@@ -0,0 +1,14 @@
+<html>
+<head>
+    <title>Authentication</title>
+</head>
+    <body>
+        <div>
+            <p>
+                A client is trying to {{ description | e }}. To confirm this action,
+                <a href="{{ redirect_url | e }}">re-authenticate with single sign-on</a>.
+                If you did not expect this, your account may be compromised!
+            </p>
+        </div>
+    </body>
+</html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 4a1fc2ec2b..46e458e95b 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
     keys,
     notifications,
     openid,
+    password_policy,
     read_marker,
     receipts,
     register,
@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
         capabilities.register_servlets(hs, client_resource)
         account_validity.register_servlets(hs, client_resource)
         relations.register_servlets(hs, client_resource)
+        password_policy.register_servlets(hs, client_resource)
 
         # moving to /_synapse/admin
         synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 42cc2b062a..ed70d448a1 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -29,7 +29,11 @@ from synapse.rest.admin._base import (
 from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
 from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
 from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
-from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
+from synapse.rest.admin.rooms import (
+    JoinRoomAliasServlet,
+    ListRoomRestServlet,
+    ShutdownRoomRestServlet,
+)
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
@@ -189,6 +193,7 @@ def register_servlets(hs, http_server):
     """
     register_servlets_for_client_rest_resource(hs, http_server)
     ListRoomRestServlet(hs).register(http_server)
+    JoinRoomAliasServlet(hs).register(http_server)
     PurgeRoomServlet(hs).register(http_server)
     SendServerNoticeServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f9b8c0a4f0..659b8a10ee 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Optional
 
-from synapse.api.constants import Membership
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -29,7 +30,7 @@ from synapse.rest.admin._base import (
     historical_admin_path_patterns,
 )
 from synapse.storage.data_stores.main.room import RoomSortOrder
-from synapse.types import create_requester
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
 from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
@@ -237,3 +238,75 @@ class ListRoomRestServlet(RestServlet):
                 response["prev_batch"] = 0
 
         return 200, response
+
+
+class JoinRoomAliasServlet(RestServlet):
+
+    PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.admin_handler = hs.get_handlers().admin_handler
+        self.state_handler = hs.get_state_handler()
+
+    async def on_POST(self, request, room_identifier):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
+
+        content = parse_json_object_from_request(request)
+
+        assert_params_in_dict(content, ["user_id"])
+        target_user = UserID.from_string(content["user_id"])
+
+        if not self.hs.is_mine(target_user):
+            raise SynapseError(400, "This endpoint can only be used with local users")
+
+        if not await self.admin_handler.get_user(target_user):
+            raise NotFoundError("User not found")
+
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+            try:
+                remote_room_hosts = [
+                    x.decode("ascii") for x in request.args[b"server_name"]
+                ]  # type: Optional[List[str]]
+            except Exception:
+                remote_room_hosts = None
+        elif RoomAlias.is_valid(room_identifier):
+            handler = self.room_member_handler
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(
+                400, "%s was not legal room ID or room alias" % (room_identifier,)
+            )
+
+        fake_requester = create_requester(target_user)
+
+        # send an invite first if the room's join rule is anything other than public
+        room_state = await self.state_handler.get_current_state(room_id)
+        join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+        if join_rules_event:
+            if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+                await self.room_member_handler.update_membership(
+                    requester=requester,
+                    target=fake_requester.user,
+                    room_id=room_id,
+                    action="invite",
+                    remote_room_hosts=remote_room_hosts,
+                    ratelimit=False,
+                )
+
+        await self.room_member_handler.update_membership(
+            requester=fake_requester,
+            target=fake_requester.user,
+            room_id=room_id,
+            action="join",
+            remote_room_hosts=remote_room_hosts,
+            ratelimit=False,
+        )
+
+        return 200, {"room_id": room_id}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d0d4999795..59593cbf6e 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,11 +14,6 @@
 # limitations under the License.
 
 import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.push.mailer import load_jinja2_templates
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
 from synapse.util.msisdn import phone_number_to_msisdn
 
 logger = logging.getLogger(__name__)
@@ -402,7 +397,7 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
@@ -411,15 +406,15 @@ class BaseSSORedirectServlet(RestServlet):
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url):
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         # to be implemented by subclasses
         raise NotImplementedError()
@@ -427,19 +422,10 @@ class BaseSSORedirectServlet(RestServlet):
 
 class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
-        super(CasRedirectServlet, self).__init__()
-        self.cas_server_url = hs.config.cas_server_url.encode("ascii")
-        self.cas_service_url = hs.config.cas_service_url.encode("ascii")
+        self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url):
-        client_redirect_url_param = urllib.parse.urlencode(
-            {b"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-        hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
-        service_param = urllib.parse.urlencode(
-            {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
-        ).encode("ascii")
-        return b"%s/login?%s" % (self.cas_server_url, service_param)
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+        return self._cas_handler.handle_redirect_request(client_redirect_url)
 
 
 class CasTicketServlet(RestServlet):
@@ -447,81 +433,15 @@ class CasTicketServlet(RestServlet):
 
     def __init__(self, hs):
         super(CasTicketServlet, self).__init__()
-        self.cas_server_url = hs.config.cas_server_url
-        self.cas_service_url = hs.config.cas_service_url
-        self.cas_displayname_attribute = hs.config.cas_displayname_attribute
-        self.cas_required_attributes = hs.config.cas_required_attributes
-        self._sso_auth_handler = SSOAuthHandler(hs)
-        self._http_client = hs.get_proxied_http_client()
-
-    async def on_GET(self, request):
-        client_redirect_url = parse_string(request, "redirectUrl", required=True)
-        uri = self.cas_server_url + "/proxyValidate"
-        args = {
-            "ticket": parse_string(request, "ticket", required=True),
-            "service": self.cas_service_url,
-        }
-        try:
-            body = await self._http_client.get_raw(uri, args)
-        except PartialDownloadError as pde:
-            # Twisted raises this error if the connection is closed,
-            # even if that's being used old-http style to signal end-of-data
-            body = pde.response
-        result = await self.handle_cas_response(request, body, client_redirect_url)
-        return result
-
-    def handle_cas_response(self, request, cas_response_body, client_redirect_url):
-        user, attributes = self.parse_cas_response(cas_response_body)
-        displayname = attributes.pop(self.cas_displayname_attribute, None)
-
-        for required_attribute, required_value in self.cas_required_attributes.items():
-            # If required attribute was not in CAS Response - Forbidden
-            if required_attribute not in attributes:
-                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
-            # Also need to check value
-            if required_value is not None:
-                actual_value = attributes[required_attribute]
-                # If required attribute value does not match expected - Forbidden
-                if required_value != actual_value:
-                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+        self._cas_handler = hs.get_cas_handler()
 
-        return self._sso_auth_handler.on_successful_auth(
-            user, request, client_redirect_url, displayname
+    async def on_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        ticket = parse_string(request, "ticket", required=True)
+        await self._cas_handler.handle_ticket_request(
+            request, client_redirect_url, ticket
         )
 
-    def parse_cas_response(self, cas_response_body):
-        user = None
-        attributes = {}
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = root[0].tag.endswith("authenticationSuccess")
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-        except Exception:
-            logger.exception("Error parsing CAS response")
-            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(
-                401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
-            )
-        return user, attributes
-
 
 class SAMLRedirectServlet(BaseSSORedirectServlet):
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -529,72 +449,10 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url):
+    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class SSOAuthHandler(object):
-    """
-    Utility class for Resources and Servlets which handle the response from a SSO
-    service
-
-    Args:
-        hs (synapse.server.HomeServer)
-    """
-
-    def __init__(self, hs):
-        self._hostname = hs.hostname
-        self._auth_handler = hs.get_auth_handler()
-        self._registration_handler = hs.get_registration_handler()
-        self._macaroon_gen = hs.get_macaroon_generator()
-
-        # Load the redirect page HTML template
-        self._template = load_jinja2_templates(
-            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
-        )[0]
-
-        self._server_name = hs.config.server_name
-
-        # cast to tuple for use with str.startswith
-        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
-
-    async def on_successful_auth(
-        self, username, request, client_redirect_url, user_display_name=None
-    ):
-        """Called once the user has successfully authenticated with the SSO.
-
-        Registers the user if necessary, and then returns a redirect (with
-        a login token) to the client.
-
-        Args:
-            username (unicode|bytes): the remote user id. We'll map this onto
-                something sane for a MXID localpath.
-
-            request (SynapseRequest): the incoming request from the browser. We'll
-                respond to it with a redirect.
-
-            client_redirect_url (unicode): the redirect_url the client gave us when
-                it first started the process.
-
-            user_display_name (unicode|None): if set, and we have to register a new user,
-                we will set their displayname to this.
-
-        Returns:
-            Deferred[none]: Completes once we have handled the request.
-        """
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart, default_display_name=user_display_name
-            )
-
-        self._auth_handler.complete_sso_login(
-            registered_user_id, request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 631cc74cb4..31435b1e1c 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -234,13 +234,21 @@ class PasswordRestServlet(RestServlet):
         if self.auth.has_access_token(request):
             requester = await self.auth.get_user_by_req(request)
             params = await self.auth_handler.validate_user_via_ui_auth(
-                requester, body, self.hs.get_ip_from_request(request)
+                requester,
+                request,
+                body,
+                self.hs.get_ip_from_request(request),
+                "modify your account password",
             )
             user_id = requester.user.to_string()
         else:
             requester = None
             result, params, _ = await self.auth_handler.check_auth(
-                [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
+                [[LoginType.EMAIL_IDENTITY]],
+                request,
+                body,
+                self.hs.get_ip_from_request(request),
+                "modify your account password",
             )
 
             if LoginType.EMAIL_IDENTITY in result:
@@ -308,7 +316,11 @@ class DeactivateAccountRestServlet(RestServlet):
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, body, self.hs.get_ip_from_request(request)
+            requester,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
             requester.user.to_string(), erase, id_server=body.get("id_server")
@@ -602,6 +614,11 @@ class ThreepidRestServlet(RestServlet):
         return 200, {"threepids": threepids}
 
     async def on_POST(self, request):
+        if not self.hs.config.enable_3pid_changes:
+            raise SynapseError(
+                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+            )
+
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
@@ -646,6 +663,11 @@ class ThreepidAddRestServlet(RestServlet):
 
     @interactive_auth_handler
     async def on_POST(self, request):
+        if not self.hs.config.enable_3pid_changes:
+            raise SynapseError(
+                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+            )
+
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
@@ -656,7 +678,11 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, body, self.hs.get_ip_from_request(request)
+            requester,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
@@ -741,10 +767,16 @@ class ThreepidDeleteRestServlet(RestServlet):
 
     def __init__(self, hs):
         super(ThreepidDeleteRestServlet, self).__init__()
+        self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
 
     async def on_POST(self, request):
+        if not self.hs.config.enable_3pid_changes:
+            raise SynapseError(
+                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+            )
+
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ["medium", "address"])
 
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 50e080673b..1787562b90 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -18,6 +18,7 @@ import logging
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
 from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.handlers.auth import SUCCESS_TEMPLATE
 from synapse.http.server import finish_request
 from synapse.http.servlet import RestServlet, parse_string
 
@@ -89,30 +90,6 @@ TERMS_TEMPLATE = """
 </html>
 """
 
-SUCCESS_TEMPLATE = """
-<html>
-<head>
-<title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
-    user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-if (window.onAuthDone) {
-    window.onAuthDone();
-} else if (window.opener && window.opener.postMessage) {
-     window.opener.postMessage("authDone", "*");
-}
-</script>
-</head>
-<body>
-    <div>
-        <p>Thank you</p>
-        <p>You may now close this window and return to the application</p>
-    </div>
-</body>
-</html>
-"""
-
 
 class AuthRestServlet(RestServlet):
     """
@@ -130,6 +107,11 @@ class AuthRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
 
+        # SSO configuration.
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+
     def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
         if not session:
@@ -142,14 +124,6 @@ class AuthRestServlet(RestServlet):
                 % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                 "sitekey": self.hs.config.recaptcha_public_key,
             }
-            html_bytes = html.encode("utf8")
-            request.setResponseCode(200)
-            request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
-            request.write(html_bytes)
-            finish_request(request)
-            return None
         elif stagetype == LoginType.TERMS:
             html = TERMS_TEMPLATE % {
                 "session": session,
@@ -158,17 +132,28 @@ class AuthRestServlet(RestServlet):
                 "myurl": "%s/r0/auth/%s/fallback/web"
                 % (CLIENT_API_PREFIX, LoginType.TERMS),
             }
-            html_bytes = html.encode("utf8")
-            request.setResponseCode(200)
-            request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
-            request.write(html_bytes)
-            finish_request(request)
-            return None
+
+        elif stagetype == LoginType.SSO and self._saml_enabled:
+            # Display a confirmation page which prompts the user to
+            # re-authenticate with their SSO provider.
+            client_redirect_url = ""
+            sso_redirect_url = self._saml_handler.handle_redirect_request(
+                client_redirect_url, session
+            )
+            html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
+        # Render the HTML and return.
+        html_bytes = html.encode("utf8")
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+        request.write(html_bytes)
+        finish_request(request)
+        return None
+
     async def on_POST(self, request, stagetype):
 
         session = parse_string(request, "session")
@@ -196,15 +181,6 @@ class AuthRestServlet(RestServlet):
                     % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                     "sitekey": self.hs.config.recaptcha_public_key,
                 }
-            html_bytes = html.encode("utf8")
-            request.setResponseCode(200)
-            request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
-            request.write(html_bytes)
-            finish_request(request)
-
-            return None
         elif stagetype == LoginType.TERMS:
             authdict = {"session": session}
 
@@ -225,17 +201,22 @@ class AuthRestServlet(RestServlet):
                     "myurl": "%s/r0/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.TERMS),
                 }
-            html_bytes = html.encode("utf8")
-            request.setResponseCode(200)
-            request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-            request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
-            request.write(html_bytes)
-            finish_request(request)
-            return None
+        elif stagetype == LoginType.SSO:
+            # The SSO fallback workflow should not post here.
+            raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
         else:
             raise SynapseError(404, "Unknown auth stage type")
 
+        # Render the HTML and return.
+        html_bytes = html.encode("utf8")
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+        request.write(html_bytes)
+        finish_request(request)
+        return None
+
     def on_OPTIONS(self, _):
         return 200, {}
 
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 94ff73f384..c0714fcfb1 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -81,7 +81,11 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, body, self.hs.get_ip_from_request(request)
+            requester,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -127,7 +131,11 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, body, self.hs.get_ip_from_request(request)
+            requester,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index f7ed4daf90..8f41a3edbf 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -263,7 +263,11 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester, body, self.hs.get_ip_from_request(request)
+            requester,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+    """REST servlet that reports the configured password policy."""
+
+    PATTERNS = client_patterns("/password_policy$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(PasswordPolicyServlet, self).__init__()
+        self.policy = hs.config.password_policy
+        self.enabled = hs.config.password_policy_enabled
+
+    def on_GET(self, request):
+        """Return (200, rules) with each configured rule keyed as 'm.<name>'."""
+        if not (self.enabled and self.policy):
+            return (200, {})  # policy disabled or unset: nothing to report
+
+        known_rules = (
+            "minimum_length",
+            "require_digit",
+            "require_symbol",
+            "require_lowercase",
+            "require_uppercase",
+        )
+        policy = self.policy
+        return (200, {"m.%s" % r: policy[r] for r in known_rules if r in policy})
+
+
+def register_servlets(hs, http_server):
+    """Create the password policy servlet and register it on http_server."""
+    servlet = PasswordPolicyServlet(hs)
+    servlet.register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a09189b1b4..431ecf4f84 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -373,6 +373,7 @@ class RegisterRestServlet(RestServlet):
         self.room_member_handler = hs.get_room_member_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
         self.ratelimiter = hs.get_registration_ratelimiter()
+        self.password_policy_handler = hs.get_password_policy_handler()
         self.clock = hs.get_clock()
 
         self._registration_flows = _calculate_registration_flows(
@@ -420,6 +421,7 @@ class RegisterRestServlet(RestServlet):
                 or len(body["password"]) > 512
             ):
                 raise SynapseError(400, "Invalid password")
+            self.password_policy_handler.validate_password(body["password"])
 
         desired_username = None
         if "username" in body:
@@ -499,7 +501,11 @@ class RegisterRestServlet(RestServlet):
             )
 
         auth_result, params, session_id = await self.auth_handler.check_auth(
-            self._registration_flows, body, self.hs.get_ip_from_request(request)
+            self._registration_flows,
+            request,
+            body,
+            self.hs.get_ip_from_request(request),
+            "register a new account",
         )
 
         # Check that we're not trying to register a denied 3pid.
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 38952a1d27..59529707df 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -188,7 +188,7 @@ class RoomKeysServlet(RestServlet):
         """
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         room_keys = await self.e2e_room_keys_handler.get_room_keys(
             user_id, version, room_id, session_id
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 66a01559e1..24d3ae5bbc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
             b" media-src 'self';"
             b" object-src 'self';",
         )
+        request.setHeader(
+            b"Referrer-Policy", b"no-referrer",
+        )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
             await self.media_repo.get_local_media(request, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 490b1b45a8..fd10d42f2f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -24,7 +24,6 @@ from six import iteritems
 
 import twisted.internet.error
 import twisted.web.http
-from twisted.internet import defer
 from twisted.web.resource import Resource
 
 from synapse.api.errors import (
@@ -114,15 +113,14 @@ class MediaRepository(object):
             "update_recently_accessed_media", self._update_recently_accessed
         )
 
-    @defer.inlineCallbacks
-    def _update_recently_accessed(self):
+    async def _update_recently_accessed(self):
         remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
         local_media = self.recently_accessed_locals
         self.recently_accessed_locals = set()
 
-        yield self.store.update_cached_last_access_time(
+        await self.store.update_cached_last_access_time(
             local_media, remote_media, self.clock.time_msec()
         )
 
@@ -138,8 +136,7 @@ class MediaRepository(object):
         else:
             self.recently_accessed_locals.add(media_id)
 
-    @defer.inlineCallbacks
-    def create_content(
+    async def create_content(
         self, media_type, upload_name, content, content_length, auth_user
     ):
         """Store uploaded content for a local user and return the mxc URL
@@ -158,11 +155,11 @@ class MediaRepository(object):
 
         file_info = FileInfo(server_name=None, file_id=media_id)
 
-        fname = yield self.media_storage.store_file(content, file_info)
+        fname = await self.media_storage.store_file(content, file_info)
 
         logger.info("Stored local media in file %r", fname)
 
-        yield self.store.store_local_media(
+        await self.store.store_local_media(
             media_id=media_id,
             media_type=media_type,
             time_now_ms=self.clock.time_msec(),
@@ -171,12 +168,11 @@ class MediaRepository(object):
             user_id=auth_user,
         )
 
-        yield self._generate_thumbnails(None, media_id, media_id, media_type)
+        await self._generate_thumbnails(None, media_id, media_id, media_type)
 
         return "mxc://%s/%s" % (self.server_name, media_id)
 
-    @defer.inlineCallbacks
-    def get_local_media(self, request, media_id, name):
+    async def get_local_media(self, request, media_id, name):
         """Responds to reqests for local media, if exists, or returns 404.
 
         Args:
@@ -190,7 +186,7 @@ class MediaRepository(object):
             Deferred: Resolves once a response has successfully been written
                 to request
         """
-        media_info = yield self.store.get_local_media(media_id)
+        media_info = await self.store.get_local_media(media_id)
         if not media_info or media_info["quarantined_by"]:
             respond_404(request)
             return
@@ -204,13 +200,12 @@ class MediaRepository(object):
 
         file_info = FileInfo(None, media_id, url_cache=url_cache)
 
-        responder = yield self.media_storage.fetch_media(file_info)
-        yield respond_with_responder(
+        responder = await self.media_storage.fetch_media(file_info)
+        await respond_with_responder(
             request, responder, media_type, media_length, upload_name
         )
 
-    @defer.inlineCallbacks
-    def get_remote_media(self, request, server_name, media_id, name):
+    async def get_remote_media(self, request, server_name, media_id, name):
         """Respond to requests for remote media.
 
         Args:
@@ -236,8 +231,8 @@ class MediaRepository(object):
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
         key = (server_name, media_id)
-        with (yield self.remote_media_linearizer.queue(key)):
-            responder, media_info = yield self._get_remote_media_impl(
+        with (await self.remote_media_linearizer.queue(key)):
+            responder, media_info = await self._get_remote_media_impl(
                 server_name, media_id
             )
 
@@ -246,14 +241,13 @@ class MediaRepository(object):
             media_type = media_info["media_type"]
             media_length = media_info["media_length"]
             upload_name = name if name else media_info["upload_name"]
-            yield respond_with_responder(
+            await respond_with_responder(
                 request, responder, media_type, media_length, upload_name
             )
         else:
             respond_404(request)
 
-    @defer.inlineCallbacks
-    def get_remote_media_info(self, server_name, media_id):
+    async def get_remote_media_info(self, server_name, media_id):
         """Gets the media info associated with the remote file, downloading
         if necessary.
 
@@ -274,8 +268,8 @@ class MediaRepository(object):
         # We linearize here to ensure that we don't try and download remote
         # media multiple times concurrently
         key = (server_name, media_id)
-        with (yield self.remote_media_linearizer.queue(key)):
-            responder, media_info = yield self._get_remote_media_impl(
+        with (await self.remote_media_linearizer.queue(key)):
+            responder, media_info = await self._get_remote_media_impl(
                 server_name, media_id
             )
 
@@ -286,8 +280,7 @@ class MediaRepository(object):
 
         return media_info
 
-    @defer.inlineCallbacks
-    def _get_remote_media_impl(self, server_name, media_id):
+    async def _get_remote_media_impl(self, server_name, media_id):
         """Looks for media in local cache, if not there then attempt to
         download from remote server.
 
@@ -299,7 +292,7 @@ class MediaRepository(object):
         Returns:
             Deferred[(Responder, media_info)]
         """
-        media_info = yield self.store.get_cached_remote_media(server_name, media_id)
+        media_info = await self.store.get_cached_remote_media(server_name, media_id)
 
         # file_id is the ID we use to track the file locally. If we've already
         # seen the file then reuse the existing ID, otherwise genereate a new
@@ -317,19 +310,18 @@ class MediaRepository(object):
                 logger.info("Media is quarantined")
                 raise NotFoundError()
 
-            responder = yield self.media_storage.fetch_media(file_info)
+            responder = await self.media_storage.fetch_media(file_info)
             if responder:
                 return responder, media_info
 
         # Failed to find the file anywhere, lets download it.
 
-        media_info = yield self._download_remote_file(server_name, media_id, file_id)
+        media_info = await self._download_remote_file(server_name, media_id, file_id)
 
-        responder = yield self.media_storage.fetch_media(file_info)
+        responder = await self.media_storage.fetch_media(file_info)
         return responder, media_info
 
-    @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id, file_id):
+    async def _download_remote_file(self, server_name, media_id, file_id):
         """Attempt to download the remote file from the given server name,
         using the given file_id as the local id.
 
@@ -351,7 +343,7 @@ class MediaRepository(object):
                 ("/_matrix/media/v1/download", server_name, media_id)
             )
             try:
-                length, headers = yield self.client.get_file(
+                length, headers = await self.client.get_file(
                     server_name,
                     request_path,
                     output_stream=f,
@@ -397,7 +389,7 @@ class MediaRepository(object):
                 )
                 raise SynapseError(502, "Failed to fetch remote media")
 
-            yield finish()
+            await finish()
 
         media_type = headers[b"Content-Type"][0].decode("ascii")
         upload_name = get_filename_from_headers(headers)
@@ -405,7 +397,7 @@ class MediaRepository(object):
 
         logger.info("Stored remote media in file %r", fname)
 
-        yield self.store.store_cached_remote_media(
+        await self.store.store_cached_remote_media(
             origin=server_name,
             media_id=media_id,
             media_type=media_type,
@@ -423,7 +415,7 @@ class MediaRepository(object):
             "filesystem_id": file_id,
         }
 
-        yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
+        await self._generate_thumbnails(server_name, media_id, file_id, media_type)
 
         return media_info
 
@@ -458,16 +450,15 @@ class MediaRepository(object):
 
         return t_byte_source
 
-    @defer.inlineCallbacks
-    def generate_local_exact_thumbnail(
+    async def generate_local_exact_thumbnail(
         self, media_id, t_width, t_height, t_method, t_type, url_cache
     ):
-        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(None, media_id, url_cache=url_cache)
         )
 
         thumbnailer = Thumbnailer(input_path)
-        t_byte_source = yield defer_to_thread(
+        t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
             thumbnailer,
@@ -490,7 +481,7 @@ class MediaRepository(object):
                     thumbnail_type=t_type,
                 )
 
-                output_path = yield self.media_storage.store_file(
+                output_path = await self.media_storage.store_file(
                     t_byte_source, file_info
                 )
             finally:
@@ -500,22 +491,21 @@ class MediaRepository(object):
 
             t_len = os.path.getsize(output_path)
 
-            yield self.store.store_local_thumbnail(
+            await self.store.store_local_thumbnail(
                 media_id, t_width, t_height, t_type, t_method, t_len
             )
 
             return output_path
 
-    @defer.inlineCallbacks
-    def generate_remote_exact_thumbnail(
+    async def generate_remote_exact_thumbnail(
         self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
     ):
-        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=False)
         )
 
         thumbnailer = Thumbnailer(input_path)
-        t_byte_source = yield defer_to_thread(
+        t_byte_source = await defer_to_thread(
             self.hs.get_reactor(),
             self._generate_thumbnail,
             thumbnailer,
@@ -537,7 +527,7 @@ class MediaRepository(object):
                     thumbnail_type=t_type,
                 )
 
-                output_path = yield self.media_storage.store_file(
+                output_path = await self.media_storage.store_file(
                     t_byte_source, file_info
                 )
             finally:
@@ -547,7 +537,7 @@ class MediaRepository(object):
 
             t_len = os.path.getsize(output_path)
 
-            yield self.store.store_remote_media_thumbnail(
+            await self.store.store_remote_media_thumbnail(
                 server_name,
                 media_id,
                 file_id,
@@ -560,8 +550,7 @@ class MediaRepository(object):
 
             return output_path
 
-    @defer.inlineCallbacks
-    def _generate_thumbnails(
+    async def _generate_thumbnails(
         self, server_name, media_id, file_id, media_type, url_cache=False
     ):
         """Generate and store thumbnails for an image.
@@ -582,7 +571,7 @@ class MediaRepository(object):
         if not requirements:
             return
 
-        input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+        input_path = await self.media_storage.ensure_media_is_in_local_cache(
             FileInfo(server_name, file_id, url_cache=url_cache)
         )
 
@@ -600,7 +589,7 @@ class MediaRepository(object):
             return
 
         if thumbnailer.transpose_method is not None:
-            m_width, m_height = yield defer_to_thread(
+            m_width, m_height = await defer_to_thread(
                 self.hs.get_reactor(), thumbnailer.transpose
             )
 
@@ -620,11 +609,11 @@ class MediaRepository(object):
         for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
             # Generate the thumbnail
             if t_method == "crop":
-                t_byte_source = yield defer_to_thread(
+                t_byte_source = await defer_to_thread(
                     self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
                 )
             elif t_method == "scale":
-                t_byte_source = yield defer_to_thread(
+                t_byte_source = await defer_to_thread(
                     self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
                 )
             else:
@@ -646,7 +635,7 @@ class MediaRepository(object):
                     url_cache=url_cache,
                 )
 
-                output_path = yield self.media_storage.store_file(
+                output_path = await self.media_storage.store_file(
                     t_byte_source, file_info
                 )
             finally:
@@ -656,7 +645,7 @@ class MediaRepository(object):
 
             # Write to database
             if server_name:
-                yield self.store.store_remote_media_thumbnail(
+                await self.store.store_remote_media_thumbnail(
                     server_name,
                     media_id,
                     file_id,
@@ -667,15 +656,14 @@ class MediaRepository(object):
                     t_len,
                 )
             else:
-                yield self.store.store_local_thumbnail(
+                await self.store.store_local_thumbnail(
                     media_id, t_width, t_height, t_type, t_method, t_len
                 )
 
         return {"width": m_width, "height": m_height}
 
-    @defer.inlineCallbacks
-    def delete_old_remote_media(self, before_ts):
-        old_media = yield self.store.get_remote_media_before(before_ts)
+    async def delete_old_remote_media(self, before_ts):
+        old_media = await self.store.get_remote_media_before(before_ts)
 
         deleted = 0
 
@@ -689,7 +677,7 @@ class MediaRepository(object):
 
             # TODO: Should we delete from the backup store
 
-            with (yield self.remote_media_linearizer.queue(key)):
+            with (await self.remote_media_linearizer.queue(key)):
                 full_path = self.filepaths.remote_media_filepath(origin, file_id)
                 try:
                     os.remove(full_path)
@@ -705,7 +693,7 @@ class MediaRepository(object):
                 )
                 shutil.rmtree(thumbnail_dir, ignore_errors=True)
 
-                yield self.store.delete_remote_media(origin, media_id)
+                await self.store.delete_remote_media(origin, media_id)
                 deleted += 1
 
         return {"deleted": deleted}
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 07e395cfd1..c46676f8fc 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
         og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
         respond_with_json_bytes(request, 200, og, send_cors=True)
 
-    @defer.inlineCallbacks
-    def _do_preview(self, url, user, ts):
+    async def _do_preview(self, url, user, ts):
         """Check the db, and download the URL and build a preview
 
         Args:
@@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
         """
         # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
-        cache_result = yield self.store.get_url_cache(url, ts)
+        cache_result = await self.store.get_url_cache(url, ts)
         if (
             cache_result
             and cache_result["expires_ts"] > ts
@@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
                 og = og.encode("utf8")
             return og
 
-        media_info = yield self._download_url(url, user)
+        media_info = await self._download_url(url, user)
 
         logger.debug("got media_info of '%s'", media_info)
 
         if _is_media(media_info["media_type"]):
             file_id = media_info["filesystem_id"]
-            dims = yield self.media_repo._generate_thumbnails(
+            dims = await self.media_repo._generate_thumbnails(
                 None, file_id, file_id, media_info["media_type"], url_cache=True
             )
 
@@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
             # request itself and benefit from the same caching etc.  But for now we
             # just rely on the caching on the master request to speed things up.
             if "og:image" in og and og["og:image"]:
-                image_info = yield self._download_url(
+                image_info = await self._download_url(
                     _rebase_url(og["og:image"], media_info["uri"]), user
                 )
 
                 if _is_media(image_info["media_type"]):
                     # TODO: make sure we don't choke on white-on-transparent images
                     file_id = image_info["filesystem_id"]
-                    dims = yield self.media_repo._generate_thumbnails(
+                    dims = await self.media_repo._generate_thumbnails(
                         None, file_id, file_id, image_info["media_type"], url_cache=True
                     )
                     if dims:
@@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
         jsonog = json.dumps(og)
 
         # store OG in history-aware DB cache
-        yield self.store.store_url_cache(
+        await self.store.store_url_cache(
             url,
             media_info["response_code"],
             media_info["etag"],
@@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
 
         return jsonog.encode("utf8")
 
-    @defer.inlineCallbacks
-    def _download_url(self, url, user):
+    async def _download_url(self, url, user):
         # TODO: we should probably honour robots.txt... except in practice
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
@@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
         with self.media_storage.store_into_file(file_info) as (f, fname, finish):
             try:
                 logger.debug("Trying to get url '%s'", url)
-                length, headers, uri, code = yield self.client.get_file(
+                length, headers, uri, code = await self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size
                 )
             except SynapseError:
@@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
                     % (traceback.format_exception_only(sys.exc_info()[0], e),),
                     Codes.UNKNOWN,
                 )
-            yield finish()
+            await finish()
 
         try:
             if b"Content-Type" in headers:
@@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
 
             download_name = get_filename_from_headers(headers)
 
-            yield self.store.store_local_media(
+            await self.store.store_local_media(
                 media_id=file_id,
                 media_type=media_type,
                 time_now_ms=self.clock.time_msec(),
@@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
             "expire_url_cache_data", self._expire_url_cache_data
         )
 
-    @defer.inlineCallbacks
-    def _expire_url_cache_data(self):
+    async def _expire_url_cache_data(self):
         """Clean up expired url cache content, media and thumbnails.
         """
         # TODO: Delete from backup media store
@@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
 
         logger.info("Running url preview cache expiry")
 
-        if not (yield self.store.db.updates.has_completed_background_updates()):
+        if not (await self.store.db.updates.has_completed_background_updates()):
             logger.info("Still running DB updates; skipping expiry")
             return
 
         # First we delete expired url cache entries
-        media_ids = yield self.store.get_expired_url_cache(now)
+        media_ids = await self.store.get_expired_url_cache(now)
 
         removed_media = []
         for media_id in media_ids:
@@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
             except Exception:
                 pass
 
-        yield self.store.delete_url_cache(removed_media)
+        await self.store.delete_url_cache(removed_media)
 
         if removed_media:
             logger.info("Deleted %d entries from url cache", len(removed_media))
@@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
         # may have a room open with a preview url thing open).
         # So we wait a couple of days before deleting, just in case.
         expire_before = now - 2 * 24 * 60 * 60 * 1000
-        media_ids = yield self.store.get_url_cache_media_before(expire_before)
+        media_ids = await self.store.get_url_cache_media_before(expire_before)
 
         removed_media = []
         for media_id in media_ids:
@@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
             except Exception:
                 pass
 
-        yield self.store.delete_url_cache_media(removed_media)
+        await self.store.delete_url_cache_media(removed_media)
 
         logger.info("Deleted %d media from url cache", len(removed_media))
 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d57480f761..0b87220234 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,8 +16,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.server import (
     DirectServeResource,
     set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
                 )
             self.media_repo.mark_recently_accessed(server_name, media_id)
 
-    @defer.inlineCallbacks
-    def _respond_local_thumbnail(
+    async def _respond_local_thumbnail(
         self, request, media_id, width, height, method, m_type
     ):
-        media_info = yield self.store.get_local_media(media_id)
+        media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
             respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
             respond_404(request)
             return
 
-        thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
 
         if thumbnail_infos:
             thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
             t_type = file_info.thumbnail_type
             t_length = thumbnail_info["thumbnail_length"]
 
-            responder = yield self.media_storage.fetch_media(file_info)
-            yield respond_with_responder(request, responder, t_type, t_length)
+            responder = await self.media_storage.fetch_media(file_info)
+            await respond_with_responder(request, responder, t_type, t_length)
         else:
             logger.info("Couldn't find any generated thumbnails")
             respond_404(request)
 
-    @defer.inlineCallbacks
-    def _select_or_generate_local_thumbnail(
+    async def _select_or_generate_local_thumbnail(
         self,
         request,
         media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
         desired_method,
         desired_type,
     ):
-        media_info = yield self.store.get_local_media(media_id)
+        media_info = await self.store.get_local_media(media_id)
 
         if not media_info:
             respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
             respond_404(request)
             return
 
-        thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+        thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
         for info in thumbnail_infos:
             t_w = info["thumbnail_width"] == desired_width
             t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
                 t_type = file_info.thumbnail_type
                 t_length = info["thumbnail_length"]
 
-                responder = yield self.media_storage.fetch_media(file_info)
+                responder = await self.media_storage.fetch_media(file_info)
                 if responder:
-                    yield respond_with_responder(request, responder, t_type, t_length)
+                    await respond_with_responder(request, responder, t_type, t_length)
                     return
 
         logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self.media_repo.generate_local_exact_thumbnail(
+        file_path = await self.media_repo.generate_local_exact_thumbnail(
             media_id,
             desired_width,
             desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
         )
 
         if file_path:
-            yield respond_with_file(request, desired_type, file_path)
+            await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
             respond_404(request)
 
-    @defer.inlineCallbacks
-    def _select_or_generate_remote_thumbnail(
+    async def _select_or_generate_remote_thumbnail(
         self,
         request,
         server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
         desired_method,
         desired_type,
     ):
-        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
-        thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+        thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
 
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
                 t_type = file_info.thumbnail_type
                 t_length = info["thumbnail_length"]
 
-                responder = yield self.media_storage.fetch_media(file_info)
+                responder = await self.media_storage.fetch_media(file_info)
                 if responder:
-                    yield respond_with_responder(request, responder, t_type, t_length)
+                    await respond_with_responder(request, responder, t_type, t_length)
                     return
 
         logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
-        file_path = yield self.media_repo.generate_remote_exact_thumbnail(
+        file_path = await self.media_repo.generate_remote_exact_thumbnail(
             server_name,
             file_id,
             media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
         )
 
         if file_path:
-            yield respond_with_file(request, desired_type, file_path)
+            await respond_with_file(request, desired_type, file_path)
         else:
             logger.warning("Failed to generate thumbnail")
             respond_404(request)
 
-    @defer.inlineCallbacks
-    def _respond_remote_thumbnail(
+    async def _respond_remote_thumbnail(
         self, request, server_name, media_id, width, height, method, m_type
     ):
         # TODO: Don't download the whole remote file
         # We should proxy the thumbnail from the remote server instead of
         # downloading the remote file and generating our own thumbnails.
-        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+        media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
 
-        thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+        thumbnail_infos = await self.store.get_remote_media_thumbnails(
             server_name, media_id
         )
 
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
             t_type = file_info.thumbnail_type
             t_length = thumbnail_info["thumbnail_length"]
 
-            responder = yield self.media_storage.fetch_media(file_info)
-            yield respond_with_responder(request, responder, t_type, t_length)
+            responder = await self.media_storage.fetch_media(file_info)
+            await respond_with_responder(request, responder, t_type, t_length)
         else:
             logger.info("Failed to find any generated thumbnails")
             respond_404(request)
diff --git a/synapse/server.py b/synapse/server.py
index 1b980371de..9228e1c892 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -56,6 +56,7 @@ from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.cas_handler import CasHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
 from synapse.handlers.devicemessage import DeviceMessageHandler
@@ -66,6 +67,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerH
 from synapse.handlers.initial_sync import InitialSyncHandler
 from synapse.handlers.message import EventCreationHandler, MessageHandler
 from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.password_policy import PasswordPolicyHandler
 from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
@@ -85,6 +87,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepositoryResource,
@@ -100,6 +103,7 @@ from synapse.storage import DataStores, Storage
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
 
@@ -196,9 +200,12 @@ class HomeServer(object):
         "sendmail",
         "registration_handler",
         "account_validity_handler",
+        "cas_handler",
         "saml_handler",
         "event_client_serializer",
+        "password_policy_handler",
         "storage",
+        "replication_streamer",
     ]
 
     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -224,6 +231,8 @@ class HomeServer(object):
         self._listening_services = []
         self.start_time = None
 
+        self.instance_id = random_string(5)
+
         self.clock = Clock(reactor)
         self.distributor = Distributor()
         self.ratelimiter = Ratelimiter()
@@ -236,6 +245,14 @@ class HomeServer(object):
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
+    def get_instance_id(self):
+        """A unique ID for this synapse process instance.
+
+        This is used to distinguish running instances in worker-based
+        deployments.
+        """
+        return self.instance_id
+
     def setup(self):
         logger.info("Setting up.")
         self.start_time = int(self.get_clock().time())
@@ -525,6 +542,9 @@ class HomeServer(object):
     def build_account_validity_handler(self):
         return AccountValidityHandler(self)
 
+    def build_cas_handler(self):
+        return CasHandler(self)
+
     def build_saml_handler(self):
         from synapse.handlers.saml_handler import SamlHandler
 
@@ -533,9 +553,15 @@ class HomeServer(object):
     def build_event_client_serializer(self):
         return EventClientSerializer(self)
 
+    def build_password_policy_handler(self):
+        return PasswordPolicyHandler(self)
+
     def build_storage(self) -> Storage:
         return Storage(self, self.datastores)
 
+    def build_replication_streamer(self) -> ReplicationStreamer:
+        return ReplicationStreamer(self)
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
@@ -557,24 +583,22 @@ def _make_dependency_method(depname):
         try:
             builder = getattr(hs, "build_%s" % (depname))
         except AttributeError:
-            builder = None
+            raise NotImplementedError(
+                "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
+            )
 
-        if builder:
-            # Prevent cyclic dependencies from deadlocking
-            if depname in hs._building:
-                raise ValueError("Cyclic dependency while building %s" % (depname,))
-            hs._building[depname] = 1
+        # Prevent cyclic dependencies from deadlocking
+        if depname in hs._building:
+            raise ValueError("Cyclic dependency while building %s" % (depname,))
 
+        hs._building[depname] = 1
+        try:
             dep = builder()
             setattr(hs, depname, dep)
-
+        finally:
             del hs._building[depname]
 
-            return dep
-
-        raise NotImplementedError(
-            "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
-        )
+        return dep
 
     setattr(HomeServer, "get_%s" % (depname), _get)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 3844f0e12f..9d1dfa71e7 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -114,3 +114,5 @@ class HomeServer(object):
         pass
     def is_mine_id(self, domain_id: str) -> bool:
         pass
+    def get_instance_id(self) -> str:
+        pass
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
             db_conn,
             "device_lists_stream",
             "stream_id",
-            extra_tables=[("user_signature_stream", "stream_id")],
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         self._cross_signing_id_gen = StreamIdGenerator(
             db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
                 },
             )
 
-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_caches", get_all_updated_caches_txn
-        )
-
     def get_cache_stream_token(self):
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
+            sql = (
+                "SELECT max(stream_id), user_id"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY user_id"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
+
+            sql = (
+                "SELECT max(stream_id), destination"
+                " FROM device_federation_outbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY destination"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
+
+            return rows
+
+        return self.db.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((user_id, device_id, stream_id, message_json))
 
         txn.executemany(sql, rows)
-
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
-        Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
-        Returns:
-            A deferred list of rows from the device inbox
-        """
-        if last_pos == current_pos:
-            return defer.succeed([])
-
-        def get_all_new_device_messages_txn(txn):
-            # We limit like this as we might have multiple rows per stream_id, and
-            # we want to make sure we always get all entries for any stream_id
-            # we return.
-            upper_pos = min(current_pos, last_pos + limit)
-            sql = (
-                "SELECT max(stream_id), user_id"
-                " FROM device_inbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY user_id"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
-
-            sql = (
-                "SELECT max(stream_id), destination"
-                " FROM device_federation_outbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY destination"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
-
-            # Order by ascending stream ordering
-            rows.sort()
-
-            return rows
-
-        return self.db.runInteraction(
-            "get_all_new_device_messages", get_all_new_device_messages_txn
-        )
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 8af5f7de54..dd3561e9b2 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple
 
 from six import iteritems
 
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import Collection, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,
@@ -40,6 +41,7 @@ from synapse.util.caches.descriptors import (
     cachedList,
 )
 from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
 
 logger = logging.getLogger(__name__)
 
@@ -112,23 +114,13 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []
 
-        # We retrieve n+1 devices from the list of outbound pokes where n is
-        # our outbound device update limit. We then check if the very last
-        # device has the same stream_id as the second-to-last device. If so,
-        # then we ignore all devices with that stream_id and only send the
-        # devices with a lower stream_id.
-        #
-        # If when culling the list we end up with no devices afterwards, we
-        # consider the device update to be too large, and simply skip the
-        # stream_id; the rationale being that such a large device list update
-        # is likely an error.
         updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
-            limit + 1,
+            limit,
         )
 
         # Return an empty list if there are no updates
@@ -166,14 +158,6 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
 
-        # if we have exceeded the limit, we need to exclude any results with the
-        # same stream_id as the last row.
-        if len(updates) > limit:
-            stream_id_cutoff = updates[-1][2]
-            now_stream_id = stream_id_cutoff - 1
-        else:
-            stream_id_cutoff = None
-
         # Perform the equivalent of a GROUP BY
         #
         # Iterate through the updates list and copy non-duplicate
@@ -181,7 +165,6 @@ class DeviceWorkerStore(SQLBaseStore):
         # the max stream_id across each set of duplicate entries
         #
         # maps (user_id, device_id) -> (stream_id, opentracing_context)
-        # as long as their stream_id does not match that of the last row
         #
         # opentracing_context contains the opentracing metadata for the request
         # that created the poke
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
         query_map = {}
         cross_signing_keys_by_user = {}
         for user_id, device_id, update_stream_id, update_context in updates:
-            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
-                # Stop processing updates
-                break
-
             if (
                 user_id in master_key_by_user
                 and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
                 if update_stream_id > previous_update_stream_id:
                     query_map[key] = (update_stream_id, update_context)
 
-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
-
-        # That should only happen if a client is spamming the server with new
-        # devices, in which case E2E isn't going to work well anyway. We'll just
-        # skip that stream_id and return an empty list, and continue with the next
-        # stream_id next time.
-        if not query_map and not cross_signing_keys_by_user:
-            return stream_id_cutoff, []
-
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
@@ -301,7 +269,14 @@ class DeviceWorkerStore(SQLBaseStore):
             prev_id = yield self._get_last_device_update_for_remote_user(
                 destination, user_id, from_stream_id
             )
-            for device_id, device in iteritems(user_devices):
+
+            # make sure we go through the devices in stream order
+            device_ids = sorted(
+                user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+            )
+
+            for device_id in device_ids:
+                device = user_devices[device_id]
                 stream_id, opentracing_context = query_map[(user_id, device_id)]
                 result = {
                     "user_id": user_id,
@@ -611,22 +586,33 @@ class DeviceWorkerStore(SQLBaseStore):
         else:
             return set()
 
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int, limit: int,
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
         """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
+
+        # The stream_id bounds are applied outside the UNION; we rely on the
+        # database to push them down into each of the inner queries.
         sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            SELECT stream_id, entity FROM (
+                SELECT stream_id, user_id AS entity FROM device_lists_stream
+                UNION ALL
+                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+            ) AS e
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
+            LIMIT ?
         """
-        return self.db.execute(
-            "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+        return await self.db.execute(
+            "get_all_device_list_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
         )
 
     @cached(max_entries=10000)
@@ -1021,29 +1007,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
-        with self._device_list_id_gen.get_next() as stream_id:
+        if not device_ids:
+            return
+
+        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
             yield self.db.runInteraction(
-                "add_device_change_to_streams",
-                self._add_device_change_txn,
+                "add_device_change_to_stream",
+                self._add_device_change_to_stream_txn,
+                user_id,
+                device_ids,
+                stream_ids,
+            )
+
+        if not hosts:
+            return stream_ids[-1]
+
+        context = get_active_span_text_map()
+        with self._device_list_id_gen.get_next_mult(
+            len(hosts) * len(device_ids)
+        ) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_outbound_poke_to_stream",
+                self._add_device_outbound_poke_to_stream_txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_id,
+                stream_ids,
+                context,
             )
-        return stream_id
 
-    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
-        now = self._clock.time_msec()
+        return stream_ids[-1]
 
+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[str],
+    ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_id,
-            )
+
+        min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
@@ -1052,7 +1058,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )
 
         self.db.simple_insert_many_txn(
@@ -1060,11 +1066,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_stream",
             values=[
                 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
-                for device_id in device_ids
+                for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
 
-        context = get_active_span_text_map()
+    def _add_device_outbound_poke_to_stream_txn(
+        self, txn, user_id, device_ids, hosts, stream_ids, context,
+    ):
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
+
+        now = self._clock.time_msec()
+        next_stream_id = iter(stream_ids)
 
         self.db.simple_insert_many_txn(
             txn,
@@ -1072,7 +1089,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 {
                     "destination": destination,
-                    "stream_id": stream_id,
+                    "stream_id": next(next_stream_id),
                     "user_id": user_id,
                     "device_id": device_id,
                     "sent": False,
@@ -1086,18 +1103,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             ],
         )
 
-    def _prune_old_outbound_device_pokes(self):
+    def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
         """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers. We keep one entry per
-        (destination, user_id) tuple to ensure that the prev_ids remain correct
-        if the server does come back.
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send an m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
         """
-        yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+        yesterday = self._clock.time_msec() - prune_age
 
         def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
             select_sql = """
-                SELECT destination, user_id, max(stream_id) as stream_id
-                FROM device_lists_outbound_pokes
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
                 GROUP BY destination, user_id
                 HAVING min(ts) < ? AND count(*) > 1
             """
@@ -1108,14 +1154,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not rows:
                 return
 
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one row with the max stream_id.
             delete_sql = """
                 DELETE FROM device_lists_outbound_pokes
-                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
             """
-
-            txn.executemany(
-                delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
-            )
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
 
             # Since we've deleted unsent deltas, we need to remove the entry
             # of last successful sent so that the prev_ids are correctly set.
@@ -1125,7 +1186,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             """
             txn.executemany(sql, ((row[0], row[1]) for row in rows))
 
-            logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+            logger.info("Pruned %d device list outbound pokes", count)
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index c9e7de7d12..e1d1bc3e05 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from collections import namedtuple
+from typing import Optional
 
 from twisted.internet import defer
 
@@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
 
         return room_id
 
-    def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+    def update_aliases_for_room(
+        self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+    ):
+        """Repoint all of the aliases for a given room, to a different room.
+
+        Args:
+            old_room_id:
+            new_room_id:
+            creator: The user to record as the creator of the new mapping.
+                If None, the creator will be left unchanged.
+        """
+
         def _update_aliases_for_room_txn(txn):
-            sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
-            txn.execute(sql, (new_room_id, creator, old_room_id))
+            update_creator_sql = ""
+            sql_params = (new_room_id, old_room_id)
+            if creator:
+                update_creator_sql = ", creator = ?"
+                sql_params = (new_room_id, creator, old_room_id)
+
+            sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
+                update_creator_sql,
+            )
+            txn.execute(sql, sql_params)
             self._invalidate_cache_and_stream(
                 txn, self.get_aliases_for_room, (old_room_id,)
             )
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 84594cf0a9..23f4570c4b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             room_entry["sessions"][row["session_id"]] = {
                 "first_message_index": row["first_message_index"],
                 "forwarded_count": row["forwarded_count"],
-                "is_verified": row["is_verified"],
+                # is_verified must be returned to the client as a boolean
+                "is_verified": bool(row["is_verified"]),
                 "session_data": json.loads(row["session_data"]),
             }
 
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..bcf746b7ef 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         return result
 
-    def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+    def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
         """Return a list of changes from the user signature stream to notify remotes.
         Note that the user signature stream represents when a user signs their
         device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
         """
         sql = """
-            SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+            SELECT stream_id, from_user_id AS user_id
             FROM user_signature_stream
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id
+            ORDER BY stream_id ASC
+            LIMIT ?
         """
         return self.db.execute(
-            "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+            "get_all_user_signature_changes_for_remotes",
+            None,
+            sql,
+            from_key,
+            to_key,
+            limit,
         )
 
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..e71c23541d 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1267,104 +1267,6 @@ class EventsStore(
         ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
         return ret
 
-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
-    def get_current_events_token(self):
-        """The current maximum token that events have reached"""
-        return self._stream_id_gen.get_current_token()
-
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_forward_event_rows(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (last_id, upper_bound))
-            new_event_updates.extend(txn)
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
-        )
-
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_backfill_event_rows(txn):
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
-        )
-
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(
         self,
@@ -1850,22 +1752,6 @@ class EventsStore(
 
         return (int(res["topological_ordering"]), int(res["stream_ordering"]))
 
-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
-        def get_all_updated_current_state_deltas_txn(txn):
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC LIMIT ?
-            """
-            txn.execute(sql, (from_token, to_token, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_current_state_deltas",
-            get_all_updated_current_state_deltas_txn,
-        )
-
     def insert_labels_for_event_txn(
         self, txn, event_id, labels, room_id, topological_ordering
     ):
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..16ea8948b1 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -35,7 +35,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import make_event_from_dict
 from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
@@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore):
         missing_events_ids = [e for e in event_ids if e not in event_entry_map]
 
         if missing_events_ids:
-            log_ctx = LoggingContext.current_context()
+            log_ctx = current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
             # Note that _get_events_from_db is also responsible for turning db rows
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
         complexity_v1 = round(state_events / 500, 2)
 
         return {"v1": complexity_v1}
+
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+        return -self._backfill_id_gen.get_current_token()
+
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
+    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+        def get_all_updated_current_state_deltas_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_current_state_deltas",
+            get_all_updated_current_state_deltas_txn,
+        )
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 80ca36dedf..cf195f8aa6 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             "get_expired_url_cache", _get_expired_url_cache_txn
         )
 
-    def delete_url_cache(self, media_ids):
+    async def delete_url_cache(self, media_ids):
         if len(media_ids) == 0:
             return
 
@@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+        return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     def get_url_cache_media_before(self, before_ts):
         sql = (
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
                     "status_msg": state.status_msg,
                     "currently_active": state.currently_active,
                 }
-                for state in presence_states
+                for stream_id, state in zip(stream_orderings, presence_states)
             ],
         )
 
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
             )
             txn.execute(sql + clause, [stream_id] + list(args))
 
-    def get_all_presence_updates(self, last_id, current_id):
+    def get_all_presence_updates(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed([])
 
         def get_all_presence_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, user_id, state, last_active_ts,"
-                " last_federation_update_ts, last_user_sync_ts, status_msg,"
-                " currently_active"
-                " FROM presence_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-            )
-            txn.execute(sql, (last_id, current_id))
+            sql = """
+                SELECT stream_id, user_id, state, last_active_ts,
+                    last_federation_update_ts, last_user_sync_ts,
+                    status_msg,
+                currently_active
+                FROM presence_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
         return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 62ac88d9f2..46f9bda773 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map):
         rule = dict(rawrule)
         rule["conditions"] = json.loads(rawrule["conditions"])
         rule["actions"] = json.loads(rawrule["actions"])
+        rule["default"] = False
         ruleslist.append(rule)
 
     # We're going to be mutating this a lot, so do a deep copy
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
 
         return total_media_quarantined
 
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = """
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (prev_id, current_id, limit))
+            return txn.fetchall()
+
+        if prev_id == current_id:
+            return defer.succeed([])
+
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = """
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
-
-        if prev_id == current_id:
-            return defer.succeed([])
-
-        return self.db.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         """Marks the room as blocked. Can be called multiple times.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..715c0346dd 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
     LoggingContextOrSentinel,
+    current_context,
     make_deferred_yieldable,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -483,7 +484,7 @@ class Database(object):
             end = monotonic_time()
             duration = end - start
 
-            LoggingContext.current_context().add_database_transaction(duration)
+            current_context().add_database_transaction(duration)
 
             transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
 
@@ -510,7 +511,7 @@ class Database(object):
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
 
-        if LoggingContext.current_context() == LoggingContext.sentinel:
+        if not current_context():
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
@@ -547,10 +548,8 @@ class Database(object):
         Returns:
             Deferred: The result of func
         """
-        parent_context = (
-            LoggingContext.current_context()
-        )  # type: Optional[LoggingContextOrSentinel]
-        if parent_context == LoggingContext.sentinel:
+        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
+        if not parent_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 7b18455469..ec61e14423 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -21,7 +21,7 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import LoggingContext, current_context
 from synapse.metrics import InFlightGauge
 
 logger = logging.getLogger(__name__)
@@ -106,7 +106,7 @@ class Measure(object):
             raise RuntimeError("Measure() objects cannot be re-used")
 
         self.start = self.clock.time()
-        parent_context = LoggingContext.current_context()
+        parent_context = current_context()
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 3925927f9f..fdff195771 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -32,7 +32,7 @@ def do_patch():
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
     """
 
-    from synapse.logging.context import LoggingContext
+    from synapse.logging.context import current_context
 
     global _already_patched
 
@@ -43,35 +43,35 @@ def do_patch():
     def new_inline_callbacks(f):
         @functools.wraps(f)
         def wrapped(*args, **kwargs):
-            start_context = LoggingContext.current_context()
+            start_context = current_context()
             changes = []  # type: List[str]
             orig = orig_inline_callbacks(_check_yield_points(f, changes))
 
             try:
                 res = orig(*args, **kwargs)
             except Exception:
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
 
                     err = "%s changed context from %s to %s on exception" % (
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     print(err, file=sys.stderr)
                     raise Exception(err)
                 raise
 
             if not isinstance(res, Deferred) or res.called:
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
 
                     err = "Completed %s changed context from %s to %s" % (
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     # print the error to stderr because otherwise all we
                     # see in travis-ci is the 500 error
@@ -79,23 +79,23 @@ def do_patch():
                     raise Exception(err)
                 return res
 
-            if LoggingContext.current_context() != LoggingContext.sentinel:
+            if current_context():
                 err = (
                     "%s returned incomplete deferred in non-sentinel context "
                     "%s (start was %s)"
-                ) % (f, LoggingContext.current_context(), start_context)
+                ) % (f, current_context(), start_context)
                 print(err, file=sys.stderr)
                 raise Exception(err)
 
             def check_ctx(r):
-                if LoggingContext.current_context() != start_context:
+                if current_context() != start_context:
                     for err in changes:
                         print(err, file=sys.stderr)
                     err = "%s completion of %s changed context from %s to %s" % (
                         "Failure" if isinstance(r, Failure) else "Success",
                         f,
                         start_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                     )
                     print(err, file=sys.stderr)
                     raise Exception(err)
@@ -127,7 +127,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         function
     """
 
-    from synapse.logging.context import LoggingContext
+    from synapse.logging.context import current_context
 
     @functools.wraps(f)
     def check_yield_points_inner(*args, **kwargs):
@@ -136,7 +136,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
         last_yield_line_no = gen.gi_frame.f_lineno
         result = None  # type: Any
         while True:
-            expected_context = LoggingContext.current_context()
+            expected_context = current_context()
 
             try:
                 isFailure = isinstance(result, Failure)
@@ -145,7 +145,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 else:
                     d = gen.send(result)
             except (StopIteration, defer._DefGen_Return) as e:
-                if LoggingContext.current_context() != expected_context:
+                if current_context() != expected_context:
                     # This happens when the context is lost sometime *after* the
                     # final yield and returning. E.g. we forgot to yield on a
                     # function that returns a deferred.
@@ -159,7 +159,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                         % (
                             f.__qualname__,
                             expected_context,
-                            LoggingContext.current_context(),
+                            current_context(),
                             f.__code__.co_filename,
                             last_yield_line_no,
                         )
@@ -173,13 +173,13 @@ def _check_yield_points(f: Callable, changes: List[str]):
                 # This happens if we yield on a deferred that doesn't follow
                 # the log context rules without wrapping in a `make_deferred_yieldable`.
                 # We raise here as this should never happen.
-                if LoggingContext.current_context() is not LoggingContext.sentinel:
+                if current_context():
                     err = (
                         "%s yielded with context %s rather than sentinel,"
                         " yielded on line %d in %s"
                         % (
                             frame.f_code.co_name,
-                            LoggingContext.current_context(),
+                            current_context(),
                             frame.f_lineno,
                             frame.f_code.co_filename,
                         )
@@ -191,7 +191,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
             except Exception as e:
                 result = Failure(e)
 
-            if LoggingContext.current_context() != expected_context:
+            if current_context() != expected_context:
 
                 # This happens because the context is lost sometime *after* the
                 # previous yield and *after* the current yield. E.g. the
@@ -206,7 +206,7 @@ def _check_yield_points(f: Callable, changes: List[str]):
                     % (
                         frame.f_code.co_name,
                         expected_context,
-                        LoggingContext.current_context(),
+                        current_context(),
                         last_yield_line_no,
                         frame.f_lineno,
                         frame.f_code.co_filename,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 2c0dcb5208..6899bcb788 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -13,10 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import itertools
 import random
 import re
 import string
+from collections import Iterable
 
 import six
 from six import PY2, PY3
@@ -126,3 +127,21 @@ def assert_valid_client_secret(client_secret):
         raise SynapseError(
             400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
         )
+
+
+def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
+    """If iterable has maxitems or fewer, return the stringification of a list
+    containing those items.
+
+    Otherwise, return the stringification of a list with the first maxitems
+    items, followed by "..." (so maxitems=0 on a non-empty iterable gives "[...]").
+
+    Args:
+        iterable: iterable to truncate; at most maxitems + 1 items are consumed
+        maxitems: number of items to return before truncating (expected >= 0)
+    """
+    # Take one extra item so that we can tell whether truncation is needed.
+    items = list(itertools.islice(iterable, maxitems + 1))
+    if len(items) <= maxitems:
+        return str(items)
+    return "[" + ", ".join([repr(r) for r in items[:maxitems]] + ["..."]) + "]"