Diffstat (limited to 'synapse')
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 8
-rw-r--r--  synapse/api/errors.py | 27
-rw-r--r--  synapse/app/admin_cmd.py | 3
-rw-r--r--  synapse/app/appservice.py | 5
-rw-r--r--  synapse/app/federation_sender.py | 17
-rw-r--r--  synapse/app/pusher.py | 5
-rw-r--r--  synapse/app/synchrotron.py | 5
-rw-r--r--  synapse/app/user_dir.py | 5
-rw-r--r--  synapse/config/emailconfig.py | 222
-rw-r--r--  synapse/config/push.py | 2
-rw-r--r--  synapse/config/registration.py | 83
-rw-r--r--  synapse/config/saml2_config.py | 1
-rw-r--r--  synapse/events/snapshot.py | 11
-rw-r--r--  synapse/federation/send_queue.py | 4
-rw-r--r--  synapse/federation/sender/__init__.py | 18
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 3
-rw-r--r--  synapse/federation/transport/server.py | 19
-rw-r--r--  synapse/handlers/admin.py | 27
-rw-r--r--  synapse/handlers/deactivate_account.py | 2
-rw-r--r--  synapse/handlers/federation.py | 10
-rw-r--r--  synapse/handlers/initial_sync.py | 2
-rw-r--r--  synapse/handlers/room.py | 24
-rw-r--r--  synapse/handlers/room_member.py | 2
-rw-r--r--  synapse/handlers/saml_handler.py | 40
-rw-r--r--  synapse/handlers/search.py | 2
-rw-r--r--  synapse/handlers/sync.py | 2
-rw-r--r--  synapse/handlers/typing.py | 2
-rw-r--r--  synapse/http/server.py | 13
-rw-r--r--  synapse/module_api/__init__.py | 47
-rw-r--r--  synapse/module_api/errors.py | 18
-rw-r--r--  synapse/notifier.py | 31
-rw-r--r--  synapse/push/push_tools.py | 2
-rw-r--r--  synapse/replication/slave/storage/events.py | 2
-rw-r--r--  synapse/replication/tcp/client.py | 14
-rw-r--r--  synapse/replication/tcp/commands.py | 17
-rw-r--r--  synapse/replication/tcp/protocol.py | 87
-rw-r--r--  synapse/replication/tcp/resource.py | 40
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 25
-rw-r--r--  synapse/replication/tcp/streams/events.py | 9
-rw-r--r--  synapse/rest/client/v2_alpha/auth.py | 4
-rw-r--r--  synapse/rest/client/versions.py | 2
-rw-r--r--  synapse/server.pyi | 12
-rw-r--r--  synapse/server_notices/server_notices_manager.py | 2
-rw-r--r--  synapse/state/__init__.py | 5
-rw-r--r--  synapse/state/v1.py | 5
-rw-r--r--  synapse/state/v2.py | 9
-rw-r--r--  synapse/storage/data_stores/main/cache.py | 2
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 2
-rw-r--r--  synapse/storage/data_stores/main/events.py | 32
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 2
-rw-r--r--  synapse/storage/data_stores/main/keys.py | 2
-rw-r--r--  synapse/storage/data_stores/main/presence.py | 2
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 189
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py | 97
-rw-r--r--  synapse/storage/data_stores/main/state.py | 11
-rw-r--r--  synapse/storage/data_stores/state/store.py | 52
-rw-r--r--  synapse/storage/prepare_database.py | 2
-rw-r--r--  synapse/storage/purge_events.py | 2
-rw-r--r--  synapse/storage/state.py | 35
-rw-r--r--  synapse/types.py | 9
-rw-r--r--  synapse/util/__init__.py | 17
-rw-r--r--  synapse/util/iterutils.py | 48
-rw-r--r--  synapse/util/module_loader.py | 2
64 files changed, 922 insertions, 481 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 0dd538d804..abd5297390 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.8.0"
+__version__ = "1.9.0.dev1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index abbc7079a3..2cbfab2569 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Tuple
 
 from six import itervalues
 
@@ -35,7 +34,7 @@ from synapse.api.errors import (
     ResourceLimitError,
 )
 from synapse.config.server import is_threepid_reserved
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
@@ -509,10 +508,7 @@ class Auth(object):
         return self.store.is_server_admin(user)
 
     def compute_auth_events(
-        self,
-        event,
-        current_state_ids: Dict[Tuple[str, str], str],
-        for_verification: bool = False,
+        self, event, current_state_ids: StateMap[str], for_verification: bool = False,
     ):
         """Given an event and current state return the list of event IDs used
         to auth an event.
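
The StateMap alias used above comes from synapse/types.py (changed in this
commit, though that hunk is not shown in this section). Assuming the obvious
generic form, a minimal sketch of the alias and how StateMap[str] expands:

    from typing import Dict, Tuple, TypeVar

    T = TypeVar("T")

    # Keys are (event type, state key) pairs, e.g.
    # ("m.room.member", "@alice:example.com"); values are whatever the type
    # parameter says: event ids for StateMap[str], events for
    # StateMap[EventBase], and so on.
    StateMap = Dict[Tuple[str, str], T]
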
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 9e9844b47c..1c9456e583 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,13 +17,15 @@
 """Contains exceptions and error codes."""
 
 import logging
-from typing import Dict
+from typing import Dict, List
 
 from six import iteritems
 from six.moves import http_client
 
 from canonicaljson import json
 
+from twisted.web import http
+
 logger = logging.getLogger(__name__)
 
 
@@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
         self.msg = msg
 
 
+class RedirectException(CodeMessageException):
+    """A pseudo-error indicating that we want to redirect the client to a different
+    location.
+
+    Attributes:
+        cookies: a list of set-cookies values to add to the response. For example:
+           b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
+    """
+
+    def __init__(self, location: bytes, http_code: int = http.FOUND):
+        """
+
+        Args:
+            location: the URI to redirect to
+            http_code: the HTTP response code
+        """
+        msg = "Redirect to %s" % (location.decode("utf-8"),)
+        super().__init__(code=http_code, msg=msg)
+        self.location = location
+
+        self.cookies = []  # type: List[bytes]
+
+
 class SynapseError(CodeMessageException):
     """A base exception type for matrix errors which have an errcode and error
     message (as well as an HTTP status code).
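
A hedged sketch of how the new exception might be raised (the path and cookie
value below are illustrative, not part of this commit); _return_html_error in
synapse/http/server.py, further down, turns it into an HTTP redirect response:

    from synapse.api.errors import RedirectException

    def on_sso_landing(request):
        # Abort normal processing and send the client elsewhere; the
        # response code defaults to 302 (http.FOUND).
        e = RedirectException(b"/_synapse/client/pick_username")
        e.cookies.append(b"session=a3fWa; Path=/; HttpOnly")
        raise e
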
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 8e36bc57d3..1c7c6ec0c8 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -84,8 +84,7 @@ class AdminCmdServer(HomeServer):
 
 
 class AdminCmdReplicationHandler(ReplicationClientHandler):
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         pass
 
     def get_streams_to_replicate(self):
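
The same mechanical conversion from @defer.inlineCallbacks to a native
coroutine is repeated for every worker's on_rdata below (appservice,
federation_sender, pusher, synchrotron, user_dir). In sketch form:

    from synapse.replication.tcp.client import ReplicationClientHandler

    class ExampleReplicationHandler(ReplicationClientHandler):
        # `yield super(...).on_rdata(...)` under @defer.inlineCallbacks
        # becomes a plain `await`; Twisted wraps coroutines in Deferreds
        # (via ensureDeferred) where callers still expect one.
        async def on_rdata(self, stream_name, token, rows):
            await super().on_rdata(stream_name, token, rows)
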
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index e82e0f11e3..2217d4a4fb 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -115,9 +115,8 @@ class ASReplicationHandler(ReplicationClientHandler):
         super(ASReplicationHandler, self).__init__(hs.get_datastore())
         self.appservice_handler = hs.get_application_service_handler()
 
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
-        yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, token, rows):
+        await super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
 
         if stream_name == "events":
             max_stream_id = self.store.get_room_max_stream_ordering()
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 83c436229c..38d11fdd0f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -145,9 +145,8 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
         super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
         self.send_handler = FederationSenderHandler(hs, self)
 
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
-        yield super(FederationSenderReplicationHandler, self).on_rdata(
+    async def on_rdata(self, stream_name, token, rows):
+        await super(FederationSenderReplicationHandler, self).on_rdata(
             stream_name, token, rows
         )
         self.send_handler.process_replication_rows(stream_name, token, rows)
@@ -159,6 +158,13 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
         args.update(self.send_handler.stream_positions())
         return args
 
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
+        # Let's wake up the transaction queue for the server in case we have
+        # pending stuff to send to it.
+        self.send_handler.wake_destination(server)
+
 
 def start(config_options):
     try:
@@ -206,7 +212,7 @@ class FederationSenderHandler(object):
     to the federation sender.
     """
 
-    def __init__(self, hs, replication_client):
+    def __init__(self, hs: FederationSenderServer, replication_client):
         self.store = hs.get_datastore()
         self._is_mine_id = hs.is_mine_id
         self.federation_sender = hs.get_federation_sender()
@@ -227,6 +233,9 @@ class FederationSenderHandler(object):
             self.store.get_room_max_stream_ordering()
         )
 
+    def wake_destination(self, server: str):
+        self.federation_sender.wake_destination(server)
+
     def stream_positions(self):
         return {"federation": self.federation_position}
 
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 09e639040a..e46b6ac598 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -141,9 +141,8 @@ class PusherReplicationHandler(ReplicationClientHandler):
 
         self.pusher_pool = hs.get_pusherpool()
 
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
-        yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, token, rows):
+        await super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
         run_in_background(self.poke_pushers, stream_name, token, rows)
 
     @defer.inlineCallbacks
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 03031ee34d..3218da07bd 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -358,9 +358,8 @@ class SyncReplicationHandler(ReplicationClientHandler):
         self.presence_handler = hs.get_presence_handler()
         self.notifier = hs.get_notifier()
 
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
-        yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, token, rows):
+        await super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
         run_in_background(self.process_and_notify, stream_name, token, rows)
 
     def get_streams_to_replicate(self):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 1257098f92..ba536d6f04 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -172,9 +172,8 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
         super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
         self.user_directory = hs.get_user_directory_handler()
 
-    @defer.inlineCallbacks
-    def on_rdata(self, stream_name, token, rows):
-        yield super(UserDirectoryReplicationHandler, self).on_rdata(
+    async def on_rdata(self, stream_name, token, rows):
+        await super(UserDirectoryReplicationHandler, self).on_rdata(
             stream_name, token, rows
         )
         if stream_name == EventsStream.NAME:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 35756bed87..74853f9faa 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -37,10 +37,12 @@ class EmailConfig(Config):
 
         self.email_enable_notifs = False
 
-        email_config = config.get("email", {})
+        email_config = config.get("email")
+        if email_config is None:
+            email_config = {}
 
-        self.email_smtp_host = email_config.get("smtp_host", None)
-        self.email_smtp_port = email_config.get("smtp_port", None)
+        self.email_smtp_host = email_config.get("smtp_host", "localhost")
+        self.email_smtp_port = email_config.get("smtp_port", 25)
         self.email_smtp_user = email_config.get("smtp_user", None)
         self.email_smtp_pass = email_config.get("smtp_pass", None)
         self.require_transport_security = email_config.get(
@@ -74,9 +76,9 @@ class EmailConfig(Config):
         self.email_template_dir = os.path.abspath(template_dir)
 
         self.email_enable_notifs = email_config.get("enable_notifs", False)
-        account_validity_renewal_enabled = config.get("account_validity", {}).get(
-            "renew_at"
-        )
+
+        account_validity_config = config.get("account_validity") or {}
+        account_validity_renewal_enabled = account_validity_config.get("renew_at")
 
         self.threepid_behaviour_email = (
             # Have Synapse handle the email sending if account_threepid_delegates.email
@@ -278,7 +280,9 @@ class EmailConfig(Config):
             self.email_notif_for_new_users = email_config.get(
                 "notif_for_new_users", True
             )
-            self.email_riot_base_url = email_config.get("riot_base_url", None)
+            self.email_riot_base_url = email_config.get(
+                "client_base_url", email_config.get("riot_base_url", None)
+            )
 
         if account_validity_renewal_enabled:
             self.email_expiry_template_html = email_config.get(
@@ -294,107 +298,111 @@ class EmailConfig(Config):
                     raise ConfigError("Unable to find email template file %s" % (p,))
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
-        return """
-        # Enable sending emails for password resets, notification events or
-        # account expiry notices
-        #
-        # If your SMTP server requires authentication, the optional smtp_user &
-        # smtp_pass variables should be used
-        #
-        #email:
-        #   enable_notifs: false
-        #   smtp_host: "localhost"
-        #   smtp_port: 25 # SSL: 465, STARTTLS: 587
-        #   smtp_user: "exampleusername"
-        #   smtp_pass: "examplepassword"
-        #   require_transport_security: false
-        #
-        #   # notif_from defines the "From" address to use when sending emails.
-        #   # It must be set if email sending is enabled.
-        #   #
-        #   # The placeholder '%(app)s' will be replaced by the application name,
-        #   # which is normally 'app_name' (below), but may be overridden by the
-        #   # Matrix client application.
-        #   #
-        #   # Note that the placeholder must be written '%(app)s', including the
-        #   # trailing 's'.
-        #   #
-        #   notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
-        #
-        #   # app_name defines the default value for '%(app)s' in notif_from. It
-        #   # defaults to 'Matrix'.
-        #   #
-        #   #app_name: my_branded_matrix_server
-        #
-        #   # Enable email notifications by default
-        #   #
-        #   notif_for_new_users: true
-        #
-        #   # Defining a custom URL for Riot is only needed if email notifications
-        #   # should contain links to a self-hosted installation of Riot; when set
-        #   # the "app_name" setting is ignored
-        #   #
-        #   riot_base_url: "http://localhost/riot"
-        #
-        #   # Configure the time that a validation email or text message code
-        #   # will expire after sending
-        #   #
-        #   # This is currently used for password resets
-        #   #
-        #   #validation_token_lifetime: 1h
-        #
-        #   # Template directory. All template files should be stored within this
-        #   # directory. If not set, default templates from within the Synapse
-        #   # package will be used
-        #   #
-        #   # For the list of default templates, please see
-        #   # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
-        #   #
-        #   #template_dir: res/templates
-        #
-        #   # Templates for email notifications
-        #   #
-        #   notif_template_html: notif_mail.html
-        #   notif_template_text: notif_mail.txt
-        #
-        #   # Templates for account expiry notices
-        #   #
-        #   expiry_template_html: notice_expiry.html
-        #   expiry_template_text: notice_expiry.txt
-        #
-        #   # Templates for password reset emails sent by the homeserver
-        #   #
-        #   #password_reset_template_html: password_reset.html
-        #   #password_reset_template_text: password_reset.txt
-        #
-        #   # Templates for registration emails sent by the homeserver
-        #   #
-        #   #registration_template_html: registration.html
-        #   #registration_template_text: registration.txt
-        #
-        #   # Templates for validation emails sent by the homeserver when adding an email to
-        #   # your user account
-        #   #
-        #   #add_threepid_template_html: add_threepid.html
-        #   #add_threepid_template_text: add_threepid.txt
-        #
-        #   # Templates for password reset success and failure pages that a user
-        #   # will see after attempting to reset their password
-        #   #
-        #   #password_reset_template_success_html: password_reset_success.html
-        #   #password_reset_template_failure_html: password_reset_failure.html
-        #
-        #   # Templates for registration success and failure pages that a user
-        #   # will see after attempting to register using an email or phone
-        #   #
-        #   #registration_template_success_html: registration_success.html
-        #   #registration_template_failure_html: registration_failure.html
+        return """\
+        # Configuration for sending emails from Synapse.
         #
-        #   # Templates for success and failure pages that a user will see after attempting
-        #   # to add an email or phone to their account
-        #   #
-        #   #add_threepid_success_html: add_threepid_success.html
-        #   #add_threepid_failure_html: add_threepid_failure.html
+        email:
+          # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'.
+          #
+          #smtp_host: mail.server
+
+          # The port on the mail server for outgoing SMTP. Defaults to 25.
+          #
+          #smtp_port: 587
+
+          # Username/password for authentication to the SMTP server. By default, no
+          # authentication is attempted.
+          #
+          # smtp_user: "exampleusername"
+          # smtp_pass: "examplepassword"
+
+          # Uncomment the following to require TLS transport security for SMTP.
+          # By default, Synapse will connect over plain text, and will then switch to
+          # TLS via STARTTLS *if the SMTP server supports it*. If this option is set,
+          # Synapse will refuse to connect unless the server supports STARTTLS.
+          #
+          #require_transport_security: true
+
+          # Enable sending emails for messages that the user has missed
+          #
+          #enable_notifs: false
+
+          # notif_from defines the "From" address to use when sending emails.
+          # It must be set if email sending is enabled.
+          #
+          # The placeholder '%(app)s' will be replaced by the application name,
+          # which is normally 'app_name' (below), but may be overridden by the
+          # Matrix client application.
+          #
+          # Note that the placeholder must be written '%(app)s', including the
+          # trailing 's'.
+          #
+          #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+
+          # app_name defines the default value for '%(app)s' in notif_from. It
+          # defaults to 'Matrix'.
+          #
+          #app_name: my_branded_matrix_server
+
+          # Uncomment the following to disable automatic subscription to email
+          # notifications for new users. Enabled by default.
+          #
+          #notif_for_new_users: false
+
+          # Custom URL for client links within the email notifications. By default
+          # links will be based on "https://matrix.to".
+          #
+          # (This setting used to be called riot_base_url; the old name is still
+          # supported for backwards-compatibility but is now deprecated.)
+          #
+          #client_base_url: "http://localhost/riot"
+
+          # Configure the time that a validation email will expire after sending.
+          # Defaults to 1h.
+          #
+          #validation_token_lifetime: 15m
+
+          # Directory in which Synapse will try to find the template files below.
+          # If not set, default templates from within the Synapse package will be used.
+          #
+          # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+          # If you *do* uncomment it, you will need to make sure that all the templates
+          # below are in the directory.
+          #
+          # Synapse will look for the following templates in this directory:
+          #
+          # * The contents of email notifications of missed events: 'notif_mail.html' and
+          #   'notif_mail.txt'.
+          #
+          # * The contents of account expiry notice emails: 'notice_expiry.html' and
+          #   'notice_expiry.txt'.
+          #
+          # * The contents of password reset emails sent by the homeserver:
+          #   'password_reset.html' and 'password_reset.txt'
+          #
+          # * HTML pages for success and failure that a user will see when they follow
+          #   the link in the password reset email: 'password_reset_success.html' and
+          #   'password_reset_failure.html'
+          #
+          # * The contents of address verification emails sent during registration:
+          #   'registration.html' and 'registration.txt'
+          #
+          # * HTML pages for success and failure that a user will see when they follow
+          #   the link in an address verification email sent during registration:
+          #   'registration_success.html' and 'registration_failure.html'
+          #
+          # * The contents of address verification emails sent when an address is added
+          #   to a Matrix account: 'add_threepid.html' and 'add_threepid.txt'
+          #
+          # * HTML pages for success and failure that a user will see when they follow
+          #   the link in an address verification email sent when an address is added
+          #   to a Matrix account: 'add_threepid_success.html' and
+          #   'add_threepid_failure.html'
+          #
+          # You can see the default templates at:
+          # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+          #
+          #template_dir: "res/templates"
         """
 
 
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 0910958649..6f2b3a7faa 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -35,7 +35,7 @@ class PushConfig(Config):
 
         # Now check for the one in the 'email' section and honour it,
         # with a warning.
-        push_config = config.get("email", {})
+        push_config = config.get("email") or {}
         redact_content = push_config.get("redact_content")
         if redact_content is not None:
             print(
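
The change from config.get("email", {}) to config.get("email") or {} (and the
equivalent None check in emailconfig.py above) matters because a YAML section
that is present but empty parses to None, which .get(key, default) returns
unchanged. A minimal illustration, assuming pyyaml is available:

    import yaml

    config = yaml.safe_load("email:\n")  # section present but empty
    print(config.get("email", {}))       # None - the default is NOT used
    print(config.get("email") or {})     # {}   - safe to call .get() on
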
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ee9614c5f7..b873995a49 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -27,6 +27,8 @@ class AccountValidityConfig(Config):
     section = "accountvalidity"
 
     def __init__(self, config, synapse_config):
+        if config is None:
+            return
         self.enabled = config.get("enabled", False)
         self.renew_by_email_enabled = "renew_at" in config
 
@@ -159,23 +161,6 @@ class RegistrationConfig(Config):
         # Optional account validity configuration. This allows for accounts to be denied
         # any request after a given period.
         #
-        # ``enabled`` defines whether the account validity feature is enabled. Defaults
-        # to False.
-        #
-        # ``period`` allows setting the period after which an account is valid
-        # after its registration. When renewing the account, its validity period
-        # will be extended by this amount of time. This parameter is required when using
-        # the account validity feature.
-        #
-        # ``renew_at`` is the amount of time before an account's expiry date at which
-        # Synapse will send an email to the account's email address with a renewal link.
-        # This needs the ``email`` and ``public_baseurl`` configuration sections to be
-        # filled.
-        #
-        # ``renew_email_subject`` is the subject of the email sent out with the renewal
-        # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
-        # from the ``email`` section.
-        #
         # Once this feature is enabled, Synapse will look for registered users without an
         # expiration date at startup and will add one to every account it found using the
         # current settings at that time.
@@ -186,21 +171,55 @@ class RegistrationConfig(Config):
         # date will be randomly selected within a range [now + period - d ; now + period],
         # where d is equal to 10%% of the validity period.
         #
-        #account_validity:
-        #  enabled: true
-        #  period: 6w
-        #  renew_at: 1w
-        #  renew_email_subject: "Renew your %%(app)s account"
-        #  # Directory in which Synapse will try to find the HTML files to serve to the
-        #  # user when trying to renew an account. Optional, defaults to
-        #  # synapse/res/templates.
-        #  template_dir: "res/templates"
-        #  # HTML to be displayed to the user after they successfully renewed their
-        #  # account. Optional.
-        #  account_renewed_html_path: "account_renewed.html"
-        #  # HTML to be displayed when the user tries to renew an account with an invalid
-        #  # renewal token. Optional.
-        #  invalid_token_html_path: "invalid_token.html"
+        account_validity:
+          # The account validity feature is disabled by default. Uncomment the
+          # following line to enable it.
+          #
+          #enabled: true
+
+          # The period after which an account is valid after its registration. When
+          # renewing the account, its validity period will be extended by this amount
+          # of time. This parameter is required when using the account validity
+          # feature.
+          #
+          #period: 6w
+
+          # The amount of time before an account's expiry date at which Synapse will
+          # send an email to the account's email address with a renewal link. By
+          # default, no such emails are sent.
+          #
+          # If you enable this setting, you will also need to fill out the 'email' and
+          # 'public_baseurl' configuration sections.
+          #
+          #renew_at: 1w
+
+          # The subject of the email sent out with the renewal link. '%%(app)s' can be
+          # used as a placeholder for the 'app_name' parameter from the 'email'
+          # section.
+          #
+          # Note that the placeholder must be written '%%(app)s', including the
+          # trailing 's'.
+          #
+          # If this is not set, a default value is used.
+          #
+          #renew_email_subject: "Renew your %%(app)s account"
+
+          # Directory in which Synapse will try to find templates for the HTML files to
+          # serve to the user when trying to renew an account. If not set, default
+          # templates from within the Synapse package will be used.
+          #
+          #template_dir: "res/templates"
+
+          # File within 'template_dir' giving the HTML to be displayed to the user after
+          # they successfully renewed their account. If not set, default text is used.
+          #
+          #account_renewed_html_path: "account_renewed.html"
+
+          # File within 'template_dir' giving the HTML to be displayed when the user
+          # tries to renew an account with an invalid renewal token. If not set,
+          # default text is used.
+          #
+          #invalid_token_html_path: "invalid_token.html"
 
         # Time that a user's session remains valid for, after they log in.
         #
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index b91414aa35..423c158b11 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -121,6 +121,7 @@ class SAML2Config(Config):
         required_methods = [
             "get_saml_attributes",
             "saml_response_to_user_attributes",
+            "get_remote_user_id",
         ]
         missing_methods = [
             method
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a44baea365..9ea85e93e6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Union
 
 from six import iteritems
 
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import StateMap
 
 
 @attr.s(slots=True)
@@ -106,13 +107,11 @@ class EventContext:
     _state_group = attr.ib(default=None, type=Optional[int])
     state_group_before_event = attr.ib(default=None, type=Optional[int])
     prev_group = attr.ib(default=None, type=Optional[int])
-    delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+    delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
     app_service = attr.ib(default=None, type=Optional[ApplicationService])
 
-    _current_state_ids = attr.ib(
-        default=None, type=Optional[Dict[Tuple[str, str], str]]
-    )
-    _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+    _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+    _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
 
     @staticmethod
     def with_state(
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index ced4925a98..174f6e42be 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -259,7 +259,9 @@ class FederationRemoteSendQueue(object):
     def federation_ack(self, token):
         self._clear_queue_before_pos(token)
 
-    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+    async def get_replication_rows(
+        self, from_token, to_token, limit, federation_ack=None
+    ):
         """Get rows to be sent over federation between the two tokens
 
         Args:
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4ebb0e8bc0..36c83c3027 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -21,6 +21,7 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
+import synapse
 import synapse.metrics
 from synapse.federation.sender.per_destination_queue import PerDestinationQueue
 from synapse.federation.sender.transaction_manager import TransactionManager
@@ -54,7 +55,7 @@ sent_pdus_destination_dist_total = Counter(
 
 
 class FederationSender(object):
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.hs = hs
         self.server_name = hs.hostname
 
@@ -482,7 +483,20 @@ class FederationSender(object):
 
     def send_device_messages(self, destination):
         if destination == self.server_name:
-            logger.info("Not sending device update to ourselves")
+            logger.warning("Not sending device update to ourselves")
+            return
+
+        self._get_per_destination_queue(destination).attempt_new_transaction()
+
+    def wake_destination(self, destination: str):
+        """Called when we want to retry sending transactions to a remote.
+
+        This is mainly useful if the remote server has been down and we think it
+        might have come back.
+        """
+
+        if destination == self.server_name:
+            logger.warning("Not waking up ourselves")
             return
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
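
A short usage sketch (the destination name is illustrative, and hs is assumed
to be a HomeServer instance):

    # Retry any transactions that queued up while the remote was down.
    hs.get_federation_sender().wake_destination("remote.example.com")
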
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index a5b36b1827..5012aaea35 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
+from synapse.types import StateMap
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
 # This is defined in the Matrix spec and enforced by the receiver.
@@ -77,7 +78,7 @@ class PerDestinationQueue(object):
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}  # type: dict[tuple[str, str], Edu]
+        self._pending_edus_keyed = {}  # type: StateMap[Edu]
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b4cbf23394..d8cf9ed299 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -44,6 +44,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
@@ -101,12 +102,17 @@ class NoAuthenticationError(AuthenticationError):
 
 
 class Authenticator(object):
-    def __init__(self, hs):
+    def __init__(self, hs: HomeServer):
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.notifier = hs.get_notifier()
+
+        self.replication_client = None
+        if hs.config.worker.worker_app:
+            self.replication_client = hs.get_tcp_replication()
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     async def authenticate_request(self, request, content):
@@ -166,6 +172,17 @@ class Authenticator(object):
         try:
             logger.info("Marking origin %r as up", origin)
             await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+            # Inform the relevant places that the remote server is back up.
+            self.notifier.notify_remote_server_up(origin)
+            if self.replication_client:
+                # If we're on a worker we try and inform master about this. The
+                # replication client doesn't hook into the notifier to avoid
+                # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                # master, which then echoes it back to us which in turn pokes
+                # the notifier.
+                self.replication_client.send_remote_server_up(origin)
+
         except Exception:
             logger.exception("Error resetting retry timings on %s", origin)
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 76d18a8ba8..60a7c938bc 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 
 import logging
+from typing import List
 
 from synapse.api.constants import Membership
-from synapse.types import RoomStreamToken
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -134,7 +136,7 @@ class AdminHandler(BaseHandler):
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = await self.store.get_rooms_for_user_where_membership_is(
+        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id, events):
+    def write_events(self, room_id: str, events: List[FrozenEvent]):
         """Write a batch of events for a room.
-
-        Args:
-            room_id (str)
-            events (list[FrozenEvent])
         """
         pass
 
-    def write_state(self, room_id, event_id, state):
+    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
-
-        Args:
-            room_id (str)
-            event_id (str)
-            state (dict[tuple[str, str], FrozenEvent])
         """
         pass
 
-    def write_invite(self, room_id, event, state):
+    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id (str)
-            event (FrozenEvent)
-            state (dict[tuple[str, str], dict]): A subset of the state at the
+            room_id
+            event
+            state: A subset of the state at the
                 invite, with a subset of the event keys (type, state_key,
                 content and sender)
         """
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4426967f88..2afb390a92 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
             user_id (str): The user ID to reject pending invites for.
         """
         user = UserID.from_string(user_id)
-        pending_invites = await self.store.get_invited_rooms_for_user(user_id)
+        pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
 
         for room in pending_invites:
             try:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 61b6713c88..d4f9a792fc 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -89,7 +89,7 @@ class _NewEventInfo:
 
     event = attr.ib(type=EventBase)
     state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
-    auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+    auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
 
 
 def shortstr(iterable, maxitems=5):
@@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(
-                        ours.values()
-                    )  # type: list[dict[tuple[str, str], str]]
+                    state_maps = list(ours.values())  # type: list[StateMap[str]]
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-        auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+        auth_events: Optional[StateMap[EventBase]],
         backfilled: bool,
     ):
         """
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 44ec3e66ae..2e6755f19c 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
         if include_archived:
             memberships.append(Membership.LEAVE)
 
-        room_list = await self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id, membership_list=memberships
         )
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9cab2adbfb..9f50196ea7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.endpoint import parse_and_validate_server_name
 from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    Requester,
+    RoomAlias,
+    RoomID,
+    RoomStreamToken,
+    StateMap,
+    StreamToken,
+    UserID,
+)
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.response_cache import ResponseCache
@@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _update_upgraded_room_pls(
-        self, requester, old_room_id, new_room_id, old_room_state,
+        self,
+        requester: Requester,
+        old_room_id: str,
+        new_room_id: str,
+        old_room_state: StateMap[str],
     ):
         """Send updated power levels in both rooms after an upgrade
 
         Args:
-            requester (synapse.types.Requester): the user requesting the upgrade
-            old_room_id (str): the id of the room to be replaced
-            new_room_id (str): the id of the replacement room
-            old_room_state (dict[tuple[str, str], str]): the state map for the old room
+            requester: the user requesting the upgrade
+            old_room_id: the id of the room to be replaced
+            new_room_id: the id of the replacement room
+            old_room_state: the state map for the old room
 
         Returns:
             Deferred
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 03bb52ccfb..15e8aa5249 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -690,7 +690,7 @@ class RoomMemberHandler(object):
 
     @defer.inlineCallbacks
     def _get_inviter(self, user_id, room_id):
-        invite = yield self.store.get_invite_for_user_in_room(
+        invite = yield self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 107f97032b..7f411b53b9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -32,6 +32,7 @@ from synapse.types import (
     mxid_localpart_allowed_characters,
 )
 from synapse.util.async_helpers import Linearizer
+from synapse.util.iterutils import chunk_seq
 
 logger = logging.getLogger(__name__)
 
@@ -132,17 +133,28 @@ class SamlHandler:
             logger.warning("SAML2 response was not signed")
             raise SynapseError(400, "SAML2 response was not signed")
 
-        logger.info("SAML2 response: %s", saml2_auth.origxml)
-        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+        logger.debug("SAML2 response: %s", saml2_auth.origxml)
+        for assertion in saml2_auth.assertions:
+            # kibana limits the length of a log field, whereas this is all rather
+            # useful, so split it up.
+            count = 0
+            for part in chunk_seq(str(assertion), 10000):
+                logger.info(
+                    "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
+                )
+                count += 1
 
-        try:
-            remote_user_id = saml2_auth.ava["uid"][0]
-        except KeyError:
-            logger.warning("SAML2 response lacks a 'uid' attestation")
-            raise SynapseError(400, "'uid' not in SAML2 response")
+        logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
 
         self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
 
+        remote_user_id = self._user_mapping_provider.get_remote_user_id(
+            saml2_auth, client_redirect_url
+        )
+
+        if not remote_user_id:
+            raise Exception("Failed to extract remote user id from SAML response")
+
         with (await self._mapping_lock.queue(self._auth_provider_id)):
             # first of all, check if we already have a mapping for this user
             logger.info(
@@ -279,6 +291,20 @@ class DefaultSamlMappingProvider(object):
         self._mxid_source_attribute = parsed_config.mxid_source_attribute
         self._mxid_mapper = parsed_config.mxid_mapper
 
+        self._grandfathered_mxid_source_attribute = (
+            module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+        )
+
+    def get_remote_user_id(
+        self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+    ):
+        """Extracts the remote user id from the SAML response"""
+        try:
+            return saml_response.ava["uid"][0]
+        except KeyError:
+            logger.warning("SAML2 response lacks a 'uid' attestation")
+            raise SynapseError(400, "'uid' not in SAML2 response")
+
     def saml_response_to_user_attributes(
         self,
         saml_response: saml2.response.AuthnResponse,
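
chunk_seq is defined in the new synapse/util/iterutils.py (listed in the
diffstat above, hunk not shown here). Assuming it simply slices a sequence
into fixed-size pieces, a sketch consistent with its use above:

    def chunk_seq(iseq, maxlen):
        """Split the given sequence into chunks of the given size."""
        return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))

    # e.g. a 25000-character assertion is logged in three parts:
    for part in chunk_seq("x" * 25000, 10000):
        print(len(part))  # 10000, 10000, 5000
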
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ef750d1497..110097eab9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
         search_filter = Filter(filter_dict)
 
         # TODO: Search through left rooms too
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
+        rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
             user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2d3b8ba73c..cd95f85e3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1662,7 +1662,7 @@ class SyncHandler(object):
             Membership.BAN,
         )
 
-        room_list = await self.store.get_rooms_for_user_where_membership_is(
+        room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id, membership_list=membership_list
         )
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index b635c339ed..d5ca9cb07b 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -257,7 +257,7 @@ class TypingHandler(object):
             "typing_key", self._latest_room_serial, rooms=[member.room_id]
         )
 
-    def get_all_typing_updates(self, last_id, current_id):
+    async def get_all_typing_updates(self, last_id, current_id):
         if last_id == current_id:
             return []
 
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 943d12c907..04bc2385a2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,8 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cgi
 import collections
+import html
 import http.client
 import logging
 import types
@@ -36,6 +36,7 @@ import synapse.metrics
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
+    RedirectException,
     SynapseError,
     UnrecognizedRequestError,
 )
@@ -153,14 +154,18 @@ def _return_html_error(f, request):
 
     Args:
         f (twisted.python.failure.Failure):
-        request (twisted.web.iweb.IRequest):
+        request (twisted.web.server.Request):
     """
     if f.check(CodeMessageException):
         cme = f.value
         code = cme.code
         msg = cme.msg
 
-        if isinstance(cme, SynapseError):
+        if isinstance(cme, RedirectException):
+            logger.info("%s redirect to %s", request, cme.location)
+            request.setHeader(b"location", cme.location)
+            request.cookies.extend(cme.cookies)
+        elif isinstance(cme, SynapseError):
             logger.info("%s SynapseError: %s - %s", request, code, msg)
         else:
             logger.error(
@@ -178,7 +183,7 @@ def _return_html_error(f, request):
             exc_info=(f.type, f.value, f.getTracebackObject()),
         )
 
-    body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
+    body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
     request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
     request.setHeader(b"Content-Length", b"%i" % (len(body),))
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 305b9b0178..d680ee95e1 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,18 +17,26 @@ import logging
 
 from twisted.internet import defer
 
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import UserID
 
+"""
+This package defines the 'stable' API which can be used by extension modules which
+are loaded into Synapse.
+"""
+
+__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+
 logger = logging.getLogger(__name__)
 
 
 class ModuleApi(object):
-    """A proxy object that gets passed to password auth providers so they
+    """A proxy object that gets passed to various plugin modules so they
     can register new users etc if necessary.
     """
 
     def __init__(self, hs, auth_handler):
-        self.hs = hs
+        self._hs = hs
 
         self._store = hs.get_datastore()
         self._auth = hs.get_auth()
@@ -64,7 +73,7 @@ class ModuleApi(object):
         """
         if username.startswith("@"):
             return username
-        return UserID(username, self.hs.hostname).to_string()
+        return UserID(username, self._hs.hostname).to_string()
 
     def check_user_exists(self, user_id):
         """Check if user exists.
@@ -111,10 +120,14 @@ class ModuleApi(object):
             displayname (str|None): The displayname of the new user.
             emails (List[str]): Emails to bind to the new user.
 
+        Raises:
+            SynapseError if there is an error performing the registration. Check the
+                'errcode' property for more information on the reason for failure.
+
         Returns:
             Deferred[str]: user_id
         """
-        return self.hs.get_registration_handler().register_user(
+        return self._hs.get_registration_handler().register_user(
             localpart=localpart, default_display_name=displayname, bind_emails=emails
         )
 
@@ -131,12 +144,34 @@ class ModuleApi(object):
         Returns:
             defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
         """
-        return self.hs.get_registration_handler().register_device(
+        return self._hs.get_registration_handler().register_device(
             user_id=user_id,
             device_id=device_id,
             initial_display_name=initial_display_name,
         )
 
+    def record_user_external_id(
+        self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
+    ) -> defer.Deferred:
+        """Record a mapping from an external user id to a mxid
+
+        Args:
+            auth_provider_id: identifier for the remote auth provider
+            remote_user_id: id on that system
+            registered_user_id: complete mxid that it is mapped to
+        """
+        return self._store.record_user_external_id(
+            auth_provider_id, remote_user_id, registered_user_id
+        )
+
+    def generate_short_term_login_token(
+        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+    ) -> str:
+        """Generate a login token suitable for m.login.token authentication"""
+        return self._hs.get_macaroon_generator().generate_short_term_login_token(
+            user_id, duration_in_ms
+        )
+
     @defer.inlineCallbacks
     def invalidate_access_token(self, access_token):
         """Invalidate an access token for a user
@@ -157,7 +192,7 @@ class ModuleApi(object):
         user_id = user_info["user"].to_string()
         if device_id:
             # delete the device, which will also delete its access tokens
-            yield self.hs.get_device_handler().delete_device(user_id, device_id)
+            yield self._hs.get_device_handler().delete_device(user_id, device_id)
         else:
             # no associated device. Just delete the access token.
             yield self._auth_handler.delete_access_token(access_token)
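
Together, the new methods let a plugin complete an SSO-style login end to end;
a hedged sketch (the provider id and mxid are illustrative):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def complete_sso_login(module_api, remote_user_id):
        user_id = "@alice:example.com"  # previously registered mxid
        # Persist the external-id -> mxid mapping ...
        yield module_api.record_user_external_id(
            "my_sso_provider", remote_user_id, user_id
        )
        # ... and mint a token the client can use with m.login.token.
        token = module_api.generate_short_term_login_token(user_id)
        return token
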
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
new file mode 100644
index 0000000000..b15441772c
--- /dev/null
+++ b/synapse/module_api/errors.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exception types which are exposed as part of the stable module API"""
+
+from synapse.api.errors import RedirectException, SynapseError  # noqa: F401
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 5f5f765bea..6132727cbd 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,11 +15,13 @@
 
 import logging
 from collections import namedtuple
+from typing import Callable, List
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
+import synapse.server
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
 from synapse.handlers.presence import format_user_presence_state
@@ -154,7 +156,7 @@ class Notifier(object):
 
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.user_to_user_stream = {}
         self.room_to_user_streams = {}
 
@@ -164,7 +166,12 @@ class Notifier(object):
         self.store = hs.get_datastore()
         self.pending_new_room_events = []
 
-        self.replication_callbacks = []
+        # Called when there are new things to stream over replication
+        self.replication_callbacks = []  # type: List[Callable[[], None]]
+
+        # Called when remote servers have come back online after having been
+        # down.
+        self.remote_server_up_callbacks = []  # type: List[Callable[[str], None]]
 
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
@@ -205,7 +212,7 @@ class Notifier(object):
             "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
         )
 
-    def add_replication_callback(self, cb):
+    def add_replication_callback(self, cb: Callable[[], None]):
         """Add a callback that will be called when some new data is available.
         Callback is not given any arguments. It should *not* return a Deferred - if
         it needs to do any asynchronous work, a background thread should be started and
@@ -213,6 +220,12 @@ class Notifier(object):
         """
         self.replication_callbacks.append(cb)
 
+    def add_remote_server_up_callback(self, cb: Callable[[str], None]):
+        """Add a callback that will be called when synapse detects a server
+        has been
+        """
+        self.remote_server_up_callbacks.append(cb)
+
     def on_new_room_event(
         self, event, room_stream_id, max_room_stream_id, extra_users=[]
     ):
@@ -522,3 +535,15 @@ class Notifier(object):
         """Notify the any replication listeners that there's a new event"""
         for cb in self.replication_callbacks:
             cb()
+
+    def notify_remote_server_up(self, server: str):
+        """Notify any replication that a remote server has come back up
+        """
+        # We call federation_sender directly rather than registering as a
+        # callback as a) we already have a reference to it and b) it introduces
+        # circular dependencies.
+        if self.federation_sender:
+            self.federation_sender.wake_destination(server)
+
+        for cb in self.remote_server_up_callbacks:
+            cb(server)
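For context, a replication component would wire into the new hook roughly like so (a sketch; `hs` and `logger` are assumed to be in scope, and the log message is illustrative). Callbacks must be synchronous, matching the `Callable[[str], None]` type:

    notifier = hs.get_notifier()

    def _on_remote_server_up(server: str) -> None:
        # Kick off any follow-up work; must not block or return a Deferred.
        logger.info("Remote server %s is reachable again", server)

    notifier.add_remote_server_up_callback(_on_remote_server_up)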
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index de5c101a58..5dae4648c0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,7 +21,7 @@ from synapse.storage import Storage
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites = yield store.get_invited_rooms_for_user(user_id)
+    invites = yield store.get_invited_rooms_for_local_user(user_id)
     joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29f35b9915..3aa6cb8b96 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -152,7 +152,7 @@ class SlavedEventStore(
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_user.invalidate((state_key,))
+            self.get_invited_rooms_for_local_user.invalidate((state_key,))
 
         if relates_to:
             self.get_relations_for_event.invalidate_many((relates_to,))
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index aa7fd90e26..fc06a7b053 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -110,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         port = hs.config.worker_replication_port
         hs.get_reactor().connectTCP(host, port, self.factory)
 
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -121,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
             token (int): stream token for this batch of rows
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
-
-        Returns:
-            Deferred|None
         """
         logger.debug("Received rdata %s -> %s", stream_name, token)
-        return self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, token, rows)
 
-    def on_position(self, stream_name, token):
+    async def on_position(self, stream_name, token):
         """Called when we get new position data. By default this just pokes
         the slave store.
 
         Can be overriden in subclasses to handle more.
         """
-        return self.store.process_replication_rows(stream_name, token, [])
+        self.store.process_replication_rows(stream_name, token, [])
 
     def on_sync(self, data):
         """When we received a SYNC we wake up any deferreds that were waiting
@@ -146,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         if d:
             d.callback(data)
 
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
     def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
         subscribe to streams.
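Since `on_rdata` and `on_position` are now coroutines, worker-specific handlers override them with `async def`; a minimal sketch (the stream-name check and follow-up method are hypothetical):

    from synapse.replication.tcp.client import ReplicationClientHandler

    class MyWorkerReplicationHandler(ReplicationClientHandler):
        async def on_rdata(self, stream_name, token, rows):
            # The base implementation pokes the slave store.
            await super().on_rdata(stream_name, token, rows)
            if stream_name == "events":
                self.notify_something(rows)  # hypothetical follow-up work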
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index cbb36b9acf..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -387,6 +387,20 @@ class UserIpCommand(Command):
         )
 
 
+class RemoteServerUpCommand(Command):
+    """Sent when a worker has detected that a remote server is no longer
+    "down" and retry timings should be reset.
+
+    If sent from a client, the server will relay it to all other workers.
+
+    Format::
+
+        REMOTE_SERVER_UP <server>
+    """
+
+    NAME = "REMOTE_SERVER_UP"
+
+
 _COMMANDS = (
     ServerCommand,
     RdataCommand,
@@ -401,6 +415,7 @@ _COMMANDS = (
     RemovePusherCommand,
     InvalidateCacheCommand,
     UserIpCommand,
+    RemoteServerUpCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     SyncCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
 
 # The commands the client is allowed to send
@@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
     InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
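`RemoteServerUpCommand` defines no fields of its own, so it inherits the base `Command` behaviour, where the whole payload is the server name. A sketch of the wire format, assuming the base class's usual `to_line`/`from_line` helpers:

    from synapse.replication.tcp.commands import RemoteServerUpCommand

    cmd = RemoteServerUpCommand("remote.example.com")
    assert cmd.NAME == "REMOTE_SERVER_UP"
    assert cmd.to_line() == "remote.example.com"
    # The protocol prefixes the command name when sending, giving a line like:
    #     REMOTE_SERVER_UP remote.example.com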
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index db0353c996..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -76,17 +76,17 @@ from synapse.replication.tcp.commands import (
     PingCommand,
     PositionCommand,
     RdataCommand,
+    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
     SyncCommand,
     UserSyncCommand,
 )
+from synapse.replication.tcp.streams import STREAMS_MAP
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
-from .streams import STREAMS_MAP
-
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -241,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
-    def handle_command(self, cmd):
+    async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>
+        By default delegates to on_<COMMAND>, which should return an awaitable.
 
         Args:
-            cmd (synapse.replication.tcp.commands.Command): received command
-
-        Returns:
-            Deferred
+            cmd: received command
         """
         handler = getattr(self, "on_%s" % (cmd.NAME,))
-        return handler(cmd)
+        await handler(cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -326,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    async def on_PING(self, line):
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    async def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -429,16 +426,16 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         BaseReplicationStreamProtocol.connectionMade(self)
         self.streamer.new_connection(self)
 
-    def on_NAME(self, cmd):
+    async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    def on_USER_SYNC(self, cmd):
-        return self.streamer.on_user_sync(
+    async def on_USER_SYNC(self, cmd):
+        await self.streamer.on_user_sync(
             self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
-    def on_REPLICATE(self, cmd):
+    async def on_REPLICATE(self, cmd):
         stream_name = cmd.stream_name
         token = cmd.token
 
@@ -449,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
-            return make_deferred_yieldable(
+            await make_deferred_yieldable(
                 defer.gatherResults(deferreds, consumeErrors=True)
             )
         else:
-            return self.subscribe_to_stream(stream_name, token)
+            await self.subscribe_to_stream(stream_name, token)
 
-    def on_FEDERATION_ACK(self, cmd):
-        return self.streamer.federation_ack(cmd.token)
+    async def on_FEDERATION_ACK(self, cmd):
+        self.streamer.federation_ack(cmd.token)
 
-    def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+    async def on_REMOVE_PUSHER(self, cmd):
+        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
-    def on_INVALIDATE_CACHE(self, cmd):
-        return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+    async def on_INVALIDATE_CACHE(self, cmd):
+        self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
-    def on_USER_IP(self, cmd):
-        return self.streamer.on_user_ip(
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.streamer.on_remote_server_up(cmd.data)
+
+    async def on_USER_IP(self, cmd):
+        self.streamer.on_user_ip(
             cmd.user_id,
             cmd.access_token,
             cmd.ip,
@@ -474,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    @defer.inlineCallbacks
-    def subscribe_to_stream(self, stream_name, token):
+    async def subscribe_to_stream(self, stream_name, token):
         """Subscribe the remote to a stream.
 
         This involves checking if they've missed anything and sending those
@@ -487,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         try:
             # Get missing updates
-            updates, current_token = yield self.streamer.get_stream_updates(
+            updates, current_token = await self.streamer.get_stream_updates(
                 stream_name, token
             )
 
@@ -560,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
         self.streamer.lost_connection(self)
@@ -572,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractmethod
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
@@ -580,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
             token (int): stream token for this batch of rows
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
-
-        Returns:
-            Deferred|None
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def on_position(self, stream_name, token):
+    async def on_position(self, stream_name, token):
         """Called when we get new position data."""
         raise NotImplementedError()
 
@@ -597,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    async def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_streams_to_replicate(self):
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -676,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-    def on_SERVER(self, cmd):
+    async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def on_RDATA(self, cmd):
+    async def on_RDATA(self, cmd):
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
@@ -701,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             # Check if this is the last of a batch of updates
             rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
-            return self.handler.on_rdata(stream_name, cmd.token, rows)
+            await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    def on_POSITION(self, cmd):
+    async def on_POSITION(self, cmd):
         # When we get a `POSITION` command it means we've finished getting
         # missing updates for the given stream, and are now up to date.
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        return self.handler.on_position(cmd.stream_name, cmd.token)
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
+    async def on_SYNC(self, cmd):
+        self.handler.on_sync(cmd.data)
 
-    def on_SYNC(self, cmd):
-        return self.handler.on_sync(cmd.data)
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.handler.on_remote_server_up(cmd.data)
 
     def replicate(self, stream_name, token):
         """Send the subscription request to the server
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index cbfdaf5773..6ebf944f66 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -23,7 +23,6 @@ from six import itervalues
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.internet.protocol import Factory
 
 from synapse.metrics import LaterGauge
@@ -121,6 +120,7 @@ class ReplicationStreamer(object):
             self.federation_sender = hs.get_federation_sender()
 
         self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -155,8 +155,7 @@ class ReplicationStreamer(object):
 
         run_as_background_process("replication_notifier", self._run_notifier_loop)
 
-    @defer.inlineCallbacks
-    def _run_notifier_loop(self):
+    async def _run_notifier_loop(self):
         self.is_looping = True
 
         try:
@@ -185,7 +184,7 @@ class ReplicationStreamer(object):
                             continue
 
                         if self._replication_torture_level:
-                            yield self.clock.sleep(
+                            await self.clock.sleep(
                                 self._replication_torture_level / 1000.0
                             )
 
@@ -196,7 +195,7 @@ class ReplicationStreamer(object):
                             stream.upto_token,
                         )
                         try:
-                            updates, current_token = yield stream.get_updates()
+                            updates, current_token = await stream.get_updates()
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
@@ -233,7 +232,7 @@ class ReplicationStreamer(object):
             self.is_looping = False
 
     @measure_func("repl.get_stream_updates")
-    def get_stream_updates(self, stream_name, token):
+    async def get_stream_updates(self, stream_name, token):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -241,7 +240,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return stream.get_updates_since(token)
+        return await stream.get_updates_since(token)
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
@@ -252,22 +251,20 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
-    @defer.inlineCallbacks
-    def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
-        yield self.presence_handler.update_external_syncs_row(
+        await self.presence_handler.update_external_syncs_row(
             conn_id, user_id, is_syncing, last_sync_ms
         )
 
     @measure_func("repl.on_remove_pusher")
-    @defer.inlineCallbacks
-    def on_remove_pusher(self, app_id, push_key, user_id):
+    async def on_remove_pusher(self, app_id, push_key, user_id):
         """A client has asked us to remove a pusher
         """
         remove_pusher_counter.inc()
-        yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+        await self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id=app_id, pushkey=push_key, user_id=user_id
         )
 
@@ -281,15 +278,24 @@ class ReplicationStreamer(object):
         getattr(self.store, cache_func).invalidate(tuple(keys))
 
     @measure_func("repl.on_user_ip")
-    @defer.inlineCallbacks
-    def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+    async def on_user_ip(
+        self, user_id, access_token, ip, user_agent, device_id, last_seen
+    ):
         """The client saw a user request
         """
         user_ip_cache_counter.inc()
-        yield self.store.insert_client_ip(
+        await self.store.insert_client_ip(
             user_id, access_token, ip, user_agent, device_id, last_seen
         )
-        yield self._server_notices_sender.on_user_ip(user_id)
+        await self._server_notices_sender.on_user_ip(user_id)
+
+    @measure_func("repl.on_remote_server_up")
+    def on_remote_server_up(self, server: str):
+        self.notifier.notify_remote_server_up(server)
+
+    def send_remote_server_up(self, server: str):
+        for conn in self.connections:
+            conn.send_remote_server_up(server)
 
     def send_sync_to_all_connections(self, data):
         """Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4ab0334fc1..e03e77199b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -19,8 +19,6 @@ import logging
 from collections import namedtuple
 from typing import Any
 
-from twisted.internet import defer
-
 logger = logging.getLogger(__name__)
 
 
@@ -144,8 +142,7 @@ class Stream(object):
         self.upto_token = self.current_token()
         self.last_token = self.upto_token
 
-    @defer.inlineCallbacks
-    def get_updates(self):
+    async def get_updates(self):
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before),
         until the `upto_token`
@@ -156,13 +153,12 @@ class Stream(object):
                 list of ``(token, row)`` entries. ``row`` will be json-serialised and
                 sent over the replication steam.
         """
-        updates, current_token = yield self.get_updates_since(self.last_token)
+        updates, current_token = await self.get_updates_since(self.last_token)
         self.last_token = current_token
 
         return updates, current_token
 
-    @defer.inlineCallbacks
-    def get_updates_since(self, from_token):
+    async def get_updates_since(self, from_token):
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -182,15 +178,16 @@ class Stream(object):
         if from_token == current_token:
             return [], current_token
 
+        logger.info("get_updates_since: %s", self.__class__)
         if self._LIMITED:
-            rows = yield self.update_function(
+            rows = await self.update_function(
                 from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
             )
 
             # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
             rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
         else:
-            rows = yield self.update_function(from_token, current_token)
+            rows = await self.update_function(from_token, current_token)
 
         updates = [(row[0], row[1:]) for row in rows]
 
@@ -295,9 +292,8 @@ class PushRulesStream(Stream):
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+    async def update_function(self, from_token, to_token, limit):
+        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
         return [(row[0], row[2]) for row in rows]
 
 
@@ -413,9 +409,8 @@ class AccountDataStream(Stream):
 
         super(AccountDataStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        global_results, room_results = yield self.store.get_all_updated_account_data(
+    async def update_function(self, from_token, to_token, limit):
+        global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
 
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 0843e5aa90..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,8 +19,6 @@ from typing import Tuple, Type
 
 import attr
 
-from twisted.internet import defer
-
 from ._base import Stream
 
 
@@ -122,16 +120,15 @@ class EventsStream(Stream):
 
         super(EventsStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, current_token, limit=None):
-        event_rows = yield self._store.get_all_new_forward_event_rows(
+    async def update_function(self, from_token, current_token, limit=None):
+        event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
         event_updates = (
             (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
         )
 
-        state_rows = yield self._store.get_all_updated_current_state_deltas(
+        state_rows = await self._store.get_all_updated_current_state_deltas(
             from_token, current_token, limit
         )
         state_updates = (
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 7a256b6ecb..50e080673b 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -206,10 +206,6 @@ class AuthRestServlet(RestServlet):
 
             return None
         elif stagetype == LoginType.TERMS:
-            if ("session" not in request.args or len(request.args["session"])) == 0:
-                raise SynapseError(400, "No session supplied")
-
-            session = request.args["session"][0]
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 2a477ad22e..3d0fefb4df 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -71,6 +71,8 @@ class VersionsRestServlet(RestServlet):
                     # Implements support for label-based filtering as described in
                     # MSC2326.
                     "org.matrix.label_based_filtering": True,
+                    # Implements support for cross signing as described in MSC1756
+                    "org.matrix.e2e_cross_signing": True,
                 },
             },
         )
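Clients can feature-detect this the same way as the other unstable flags, by inspecting `unstable_features` in the `/versions` response; a quick sketch (server name illustrative):

    import requests

    resp = requests.get("https://matrix.example.com/_matrix/client/versions")
    features = resp.json().get("unstable_features", {})
    if features.get("org.matrix.e2e_cross_signing"):
        print("server advertises MSC1756 cross-signing support")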
diff --git a/synapse/server.pyi b/synapse/server.pyi
index b5e0b57095..0731403047 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,5 @@
+import twisted.internet
+
 import synapse.api.auth
 import synapse.config.homeserver
 import synapse.federation.sender
@@ -9,10 +11,12 @@ import synapse.handlers.deactivate_account
 import synapse.handlers.device
 import synapse.handlers.e2e_keys
 import synapse.handlers.message
+import synapse.handlers.presence
 import synapse.handlers.room
 import synapse.handlers.room_member
 import synapse.handlers.set_password
 import synapse.http.client
+import synapse.notifier
 import synapse.rest.media.v1.media_repository
 import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
@@ -85,3 +89,11 @@ class HomeServer(object):
         self,
     ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
         pass
+    def get_notifier(self) -> synapse.notifier.Notifier:
+        pass
+    def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
+        pass
+    def get_clock(self) -> synapse.util.Clock:
+        pass
+    def get_reactor(self) -> twisted.internet.base.ReactorBase:
+        pass
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 2dac90578c..f7432c8d2f 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -105,7 +105,7 @@ class ServerNoticesManager(object):
 
         assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
 
-        rooms = yield self._store.get_rooms_for_user_where_membership_is(
+        rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE, Membership.JOIN]
         )
         system_mxid = self._config.server_notices_mxid
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 5accc071ab..cacd0c0c2b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional
 
 from six import iteritems, itervalues
 
@@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
 def resolve_events_with_store(
     room_id: str,
     room_version: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "StateResolutionStore",
 ):
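For readers following the type change: `StateMap` (from `synapse.types`, also touched in this changeset) is, in essence, a generic alias for the `(event type, state key) -> value` mapping that was previously written out longhand. Roughly:

    from typing import Dict, Tuple, TypeVar

    T = TypeVar("T")

    # Approximately what synapse.types.StateMap provides:
    StateMap = Dict[Tuple[str, str], T]

    example: StateMap[str] = {("m.room.member", "@alice:example.com"): "$event_id"}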
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index b2f9865f39..d6c34ce3b7 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,7 +15,7 @@
 
 import hashlib
 import logging
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional
 
 from six import iteritems, iterkeys, itervalues
 
@@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 @defer.inlineCallbacks
 def resolve_events_with_store(
     room_id: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_map_factory: Callable,
 ):
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 72fb8a6317..6216fdd204 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,7 @@
 import heapq
 import itertools
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 from six import iteritems, itervalues
 
@@ -27,6 +27,7 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
 def resolve_events_with_store(
     room_id: str,
     room_version: str,
-    state_sets: List[Dict[Tuple[str, str], str]],
+    state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
 ):
@@ -393,12 +394,12 @@ def _iterative_auth_checks(
         room_id (str)
         room_version (str)
         event_ids (list[str]): Ordered list of events to apply auth checks to
-        base_state (dict[tuple[str, str], str]): The set of state to start with
+        base_state (StateMap[str]): The set of state to start with
         event_map (dict[str,FrozenEvent])
         state_res_store (StateResolutionStore)
 
     Returns:
-        Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+        Deferred[StateMap[str]]: Returns the final updated state
     """
     resolved_state = base_state.copy()
 
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 54ed8574c4..bf91512daf 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from synapse.util import batch_iter
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
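`batch_iter` itself is unchanged by the move to `synapse.util.iterutils`; for reference, its behaviour can be reproduced in a few lines (a sketch, chunking an iterable into size-limited tuples):

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        # Yield tuples of up to `size` items until the source is exhausted.
        sourceiter = iter(iterable)
        return iter(lambda: tuple(islice(sourceiter, size)), ())

    assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]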
 
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 9a828231c4..f0a7962dd0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -33,13 +33,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import (
     Cache,
     cached,
     cachedInlineCallbacks,
     cachedList,
 )
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 58f35d7f56..bb69c20448 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -43,9 +43,9 @@ from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.storage.database import Database
 from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
@@ -128,6 +128,7 @@ class EventsStore(
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+        self.is_mine_id = hs.is_mine_id
 
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
@@ -547,6 +548,34 @@ class EventsStore(
                 ],
             )
 
+            # Note: Do we really want to delete rows here (that we do not
+            # subsequently reinsert below)? While technically correct, it means
+            # we have no record of the fact the user *was* a member of the
+            # room but got, say, state reset out of it.
+            if to_delete or to_insert:
+                txn.executemany(
+                    "DELETE FROM local_current_membership"
+                    " WHERE room_id = ? AND user_id = ?",
+                    (
+                        (room_id, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                        if etype == EventTypes.Member and self.is_mine_id(state_key)
+                    ),
+                )
+
+            if to_insert:
+                txn.executemany(
+                    """INSERT INTO local_current_membership
+                        (room_id, user_id, event_id, membership)
+                    VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                    """,
+                    [
+                        (room_id, key[1], ev_id, ev_id)
+                        for key, ev_id in to_insert.items()
+                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                    ],
+                )
+
             txn.call_after(
                 self._curr_state_delta_stream_cache.entity_has_changed,
                 room_id,
@@ -1724,6 +1753,7 @@ class EventsStore(
             "local_invites",
             "room_account_data",
             "room_tags",
+            "local_current_membership",
         ):
             logger.info("[purge] removing %s from %s", room_id, table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 0cce5232f5..3b93e0597a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index 6b12f5a75f..ba89c68c9f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -23,8 +23,8 @@ from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.keys import FetchKeyResult
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index a2c83e0867..604c8b7ddd 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -17,8 +17,8 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.presence import UserPresenceState
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 
 class PresenceStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 70ff5751b6..9acef7c950 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return {row[0]: row[1] for row in txn}
 
     @cached()
-    def get_invited_rooms_for_user(self, user_id):
-        """ Get all the rooms the user is invited to
+    def get_invited_rooms_for_local_user(self, user_id):
+        """ Get all the rooms the *local* user is invited to
+
         Args:
             user_id (str): The user ID.
         Returns:
             A deferred list of RoomsForUser.
         """
 
-        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
+        return self.get_rooms_for_local_user_where_membership_is(
+            user_id, [Membership.INVITE]
+        )
 
     @defer.inlineCallbacks
-    def get_invite_for_user_in_room(self, user_id, room_id):
-        """Gets the invite for the given user and room
+    def get_invite_for_local_user_in_room(self, user_id, room_id):
+        """Gets the invite for the given *local* user and room
 
         Args:
             user_id (str)
@@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             Deferred: Resolves to either a RoomsForUser or None if no invite was
                 found.
         """
-        invites = yield self.get_invited_rooms_for_user(user_id)
+        invites = yield self.get_invited_rooms_for_local_user(user_id)
         for invite in invites:
             if invite.room_id == room_id:
                 return invite
         return None
 
     @defer.inlineCallbacks
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        """ Get all the rooms for this user where the membership for this user
+    def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+        """ Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
 
         Filters out forgotten rooms.
@@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return defer.succeed(None)
 
         rooms = yield self.db.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            self._get_rooms_for_user_where_membership_is_txn,
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
             user_id,
             membership_list,
         )
@@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
         return [room for room in rooms if room.room_id not in forgotten_rooms]
 
-    def _get_rooms_for_user_where_membership_is_txn(
+    def _get_rooms_for_local_user_where_membership_is_txn(
         self, txn, user_id, membership_list
     ):
+        # Paranoia check.
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+                % (user_id,),
+            )
 
-        do_invite = Membership.INVITE in membership_list
-        membership_list = [m for m in membership_list if m != Membership.INVITE]
-
-        results = []
-        if membership_list:
-            if self._current_state_events_membership_up_to_date:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "c.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-            else:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "m.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN room_memberships AS m USING (room_id, event_id)
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-
-            txn.execute(sql, (user_id, *args))
-            results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "c.membership", membership_list
+        )
 
-        if do_invite:
-            sql = (
-                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
-                " FROM local_invites as i"
-                " INNER JOIN events as e USING (event_id)"
-                " WHERE invitee = ? AND locally_rejected is NULL"
-                " AND replaced_by is NULL"
-            )
+        sql = """
+            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+            FROM local_current_membership AS c
+            INNER JOIN events AS e USING (room_id, event_id)
+            WHERE
+                user_id = ?
+                AND %s
+        """ % (
+            clause,
+        )
 
-            txn.execute(sql, (user_id,))
-            results.extend(
-                RoomsForUser(
-                    room_id=r["room_id"],
-                    sender=r["inviter"],
-                    event_id=r["event_id"],
-                    stream_ordering=r["stream_ordering"],
-                    membership=Membership.INVITE,
-                )
-                for r in self.db.cursor_to_dict(txn)
-            )
+        txn.execute(sql, (user_id, *args))
+        results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
 
         return results
 
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
+    @cached(max_entries=500000, iterable=True)
     def get_rooms_for_user_with_stream_ordering(self, user_id):
-        """Returns a set of room_ids the user is currently joined to
+        """Returns a set of room_ids the user is currently joined to.
+
+        For a remote user this only returns rooms this server is currently
+        participating in.
 
         Args:
             user_id (str)
@@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             the rooms the user is in currently, along with the stream ordering
             of the most recent join for that user and room.
         """
-        rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN]
-        )
-        return frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
+        return self.db.runInteraction(
+            "get_rooms_for_user_with_stream_ordering",
+            self._get_rooms_for_user_with_stream_ordering_txn,
+            user_id,
         )
 
+    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+        # We use `current_state_events` here and not `local_current_membership`
+        # as a) this gets called with remote users and b) this only gets called
+        # for rooms the server is participating in.
+        if self._current_state_events_membership_up_to_date:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND c.membership = ?
+            """
+        else:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND m.membership = ?
+            """
+
+        txn.execute(sql, (user_id, Membership.JOIN))
+        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+        return results
+
     @defer.inlineCallbacks
     def get_rooms_for_user(self, user_id, on_invalidate=None):
-        """Returns a set of room_ids the user is currently joined to
+        """Returns a set of room_ids the user is currently joined to.
+
+        For a remote user this only returns rooms this server is currently
+        participating in.
         """
         rooms = yield self.get_rooms_for_user_with_stream_ordering(
             user_id, on_invalidate=on_invalidate
@@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 event.internal_metadata.stream_ordering,
             )
             txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+                self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
             )
 
             # We update the local_invites table only if the event is "current",
@@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                         ),
                     )
 
+                # We also update the `local_current_membership` table with
+                # latest invite info. This will usually get updated by the
+                # `current_state_events` handling, unless it's an outlier.
+                if event.internal_metadata.is_outlier():
+                    # This should only happen for out of band memberships, so
+                    # we add a paranoia check.
+                    assert event.internal_metadata.is_out_of_band_membership()
+
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="local_current_membership",
+                        keyvalues={
+                            "room_id": event.room_id,
+                            "user_id": event.state_key,
+                        },
+                        values={
+                            "event_id": event.event_id,
+                            "membership": event.membership,
+                        },
+                    )
+
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
         sql = (
@@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         def f(txn, stream_ordering):
             txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # never mind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
+
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
 
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..601c236c4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room, then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in the `run_upgrade` section as we don't have access
+    # to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of the bookkeeping required to do it as
+    # a background update.
+
+    # We check if the `current_state_events.membership` is up to date by
+    # checking if the relevant background update has finished. If it has
+    # finished we can avoid doing a join against `room_memberships`, which
+    # speeds things up.
+    cur.execute(
+        """SELECT 1 FROM background_updates
+            WHERE update_name = 'current_state_events_membership'
+        """
+    )
+    current_state_membership_up_to_date = not bool(cur.fetchone())
+
+    # Cheekily drop and recreate indices, as that is faster.
+    cur.execute("DROP INDEX local_current_membership_idx")
+    cur.execute("DROP INDEX local_current_membership_room_idx")
+
+    if current_state_membership_up_to_date:
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, c.membership
+                FROM current_state_events AS c
+                WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key like '%' || ?
+        """
+    else:
+        # We can't rely on the membership column, so we need to join against
+        # `room_memberships`.
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, r.membership
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS r USING (event_id)
+                WHERE type = 'm.room.member' and state_key like '%' || ?
+        """
+    cur.execute(sql, (config.server_name,))
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    cur.execute(
+        """
+        CREATE TABLE local_current_membership (
+            room_id TEXT NOT NULL,
+            user_id TEXT NOT NULL,
+            event_id TEXT NOT NULL,
+            membership TEXT NOT NULL
+        )"""
+    )
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
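The `state_key like '%' || ?` filter in `run_upgrade` selects local users by matching the server-name suffix of the mxid; a tiny sketch of the pattern it builds (server name illustrative):

    server_name = "example.com"      # what gets bound to the ? placeholder
    pattern = "%" + server_name      # SQL: '%' || ?  ->  '%example.com'

    # LIKE-style suffix match, as the upgrade applies it to state_key:
    assert "@alice:example.com".endswith(server_name)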
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+    def get_filtered_current_state_ids(
+        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
 
         Args:
-            room_id (str)
-            state_filter (StateFilter): The state filter used to fetch state
+            room_id
+            state_filter: The state filter used to fetch state
                 from the database.
 
         Returns:
-            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
-            event ID.
+            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from six import iteritems
 from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
-        """Returns the state groups for a given set of groups, filtering on
-        types of state events.
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
+        """Returns the state groups for a given set of groups from the
+        database, filtering on types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         results = {}
 
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
+            groups: list of state groups for which we want
                 to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state
 
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+    def _get_state_for_groups_using_cache(
+        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            groups: list of state groups for which we want to get the state.
+            cache: the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the
+                specific members state cache.
+            state_filter: The state filter used to fetch state from the
+                database.
 
         Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
+            Tuple of a dict from state_group_id to the state map entries
+            found in the cache, and the set of state group ids that are
+            either missing from the cache or incomplete.
         """
         results = {}
         incomplete_groups = set()
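The member/non-member split described in these docstrings works roughly as follows. This is a simplified sketch of the control flow, not the real implementation: the cache objects and fetch_from_db callable are stand-ins, and the real code also tracks partially cached groups as incomplete:

def get_state_for_groups_sketch(
    groups, state_filter, member_cache, non_member_cache, fetch_from_db
):
    member_filter, non_member_filter = state_filter.get_member_split()

    results = {}
    missing = set()
    for cache, sub_filter in (
        (member_cache, member_filter),
        (non_member_cache, non_member_filter),
    ):
        for group in groups:
            cached = cache.get(group)  # None on a cache miss
            if cached is not None:
                results.setdefault(group, {}).update(sub_filter.filter_state(cached))
            else:
                missing.add(group)

    # Only groups missing from a cache hit the database.
    for group, state_map in fetch_from_db(list(missing), state_filter).items():
        results.setdefault(group, {}).update(state_map)
    return results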
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e70026b80a..e86984cd50 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+SCHEMA_VERSION = 57
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index d6a7bd7834..fdc0abf5cf 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -34,7 +34,7 @@ class PurgeEventsStorage(object):
         """
 
         state_groups_to_delete = yield self.stores.main.purge_room(room_id)
-        yield self.stores.main.purge_room_state(room_id, state_groups_to_delete)
+        yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
 
     @defer.inlineCallbacks
     def purge_history(self, room_id, token, delete_local_events):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Iterable, List, TypeVar
 
 from six import iteritems, itervalues
 
@@ -22,9 +23,13 @@ import attr
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
+# Used for generic functions below
+T = TypeVar("T")
+
 
 @attr.s(slots=True)
 class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict):
+    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
         """Returns the state filtered with by this StateFilter
 
         Args:
-            state (dict[tuple[str, str], Any]): The state map to filter
+            state_dict: The state map to filter
 
         Returns:
-            dict[tuple[str, str], Any]: The filtered state map
+            The filtered state map
         """
         if self.is_full():
             return dict(state_dict)
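Concretely, filter_state keeps only the entries matching this filter. A minimal sketch; the state map below is invented for illustration:

from synapse.storage.state import StateFilter

state = {
    ("m.room.member", "@alice:example.com"): "$member_event",
    ("m.room.name", ""): "$name_event",
}
name_filter = StateFilter.from_types([("m.room.name", None)])
assert name_filter.filter_state(state) == {("m.room.name", ""): "$name_event"}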
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group):
+    def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
         """
 
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
             event_ids (iterable[str]): ids of the events
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
+            Deferred[Dict[int, StateMap[str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
             for group, event_id_map in iteritems(group_to_ids)
         }
 
-    def _get_state_groups_from_groups(self, groups, state_filter):
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
diff --git a/synapse/types.py b/synapse/types.py
index cd996c0b5a..65e4d8c181 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -17,6 +17,7 @@ import re
 import string
 import sys
 from collections import namedtuple
+from typing import Dict, Tuple, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
 if sys.version_info[:3] >= (3, 6, 0):
     from typing import Collection
 else:
-    from typing import Sized, Iterable, Container, TypeVar
+    from typing import Sized, Iterable, Container
 
     T_co = TypeVar("T_co", covariant=True)
 
@@ -36,6 +37,12 @@ else:
         __slots__ = ()
 
 
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateMap = Dict[Tuple[str, str], T]
+
+
 class Requester(
     namedtuple(
         "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 7856353002..60f0de70f7 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,6 @@
 
 import logging
 import re
-from itertools import islice
 
 import attr
 
@@ -107,22 +106,6 @@ class Clock(object):
                 raise
 
 
-def batch_iter(iterable, size):
-    """batch an iterable up into tuples with a maximum size
-
-    Args:
-        iterable (iterable): the iterable to slice
-        size (int): the maximum batch size
-
-    Returns:
-        an iterator over the chunks
-    """
-    # make sure we can deal with iterables like lists too
-    sourceiter = iter(iterable)
-    # call islice until it returns an empty tuple
-    return iter(lambda: tuple(islice(sourceiter, size)), ())
-
-
 def log_failure(failure, msg, consumeErrors=True):
     """Creates a function suitable for passing to `Deferred.addErrback` that
     logs any failures that occur.
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
new file mode 100644
index 0000000000..06faeebe7f
--- /dev/null
+++ b/synapse/util/iterutils.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import islice
+from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
+    """Batch an iterable up into tuples with a maximum size
+
+    Args:
+        iterable: the iterable to slice
+        size: the maximum batch size
+
+    Returns:
+        an iterator over the chunks
+    """
+    # make sure we can deal with iterables like lists too
+    sourceiter = iter(iterable)
+    # call islice until it returns an empty tuple
+    return iter(lambda: tuple(islice(sourceiter, size)), ())
+
+
+ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
+
+
+def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+    """Split the given sequence into chunks of the given size
+
+    The last chunk may be shorter than the given size.
+
+    If the input is empty, no chunks are returned.
+    """
+    return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
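Usage of the relocated and new helpers, with expected output in comments:

from synapse.util.iterutils import batch_iter, chunk_seq

for batch in batch_iter(range(7), 3):
    print(batch)  # (0, 1, 2) then (3, 4, 5) then (6,)

# chunk_seq preserves the sequence type, e.g. when splitting a long string
# into transport-sized pieces; the final chunk may be shorter.
assert list(chunk_seq("abcdefg", 3)) == ["abc", "def", "g"]
assert list(chunk_seq([], 3)) == []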
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 2705cbe5f8..bb62db4637 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -34,7 +34,7 @@ def load_module(provider):
     provider_class = getattr(module, clz)
 
     try:
-        provider_config = provider_class.parse_config(provider["config"])
+        provider_config = provider_class.parse_config(provider.get("config"))
     except Exception as e:
         raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
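The practical effect of this last one-liner: a module declared without a config key no longer fails with a KeyError wrapped in a ConfigError; its parse_config now receives None. A sketch under that assumption, with a hypothetical provider class:

class MyProvider:
    @staticmethod
    def parse_config(config):
        # parse_config must now tolerate None when no config was supplied.
        return config or {}


provider = {"module": "my_module.MyProvider"}  # note: no "config" key
provider_config = MyProvider.parse_config(provider.get("config"))  # {} rather than KeyError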