summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py6
-rw-r--r--synapse/api/errors.py43
-rw-r--r--synapse/app/_base.py58
-rw-r--r--synapse/app/appservice.py51
-rw-r--r--synapse/app/client_reader.py52
-rw-r--r--synapse/app/federation_reader.py51
-rw-r--r--synapse/app/federation_sender.py51
-rw-r--r--synapse/app/frontend_proxy.py51
-rwxr-xr-xsynapse/app/homeserver.py99
-rw-r--r--synapse/app/media_repository.py61
-rw-r--r--synapse/app/pusher.py51
-rw-r--r--synapse/app/synchrotron.py54
-rwxr-xr-xsynapse/app/synctl.py31
-rw-r--r--synapse/app/user_dir.py51
-rw-r--r--synapse/config/homeserver.py3
-rw-r--r--synapse/config/logger.py50
-rw-r--r--synapse/config/password_auth_providers.py4
-rw-r--r--synapse/config/registration.py19
-rw-r--r--synapse/config/repository.py87
-rw-r--r--synapse/config/server.py41
-rw-r--r--synapse/config/tls.py2
-rw-r--r--synapse/config/user_directory.py44
-rw-r--r--synapse/config/workers.py5
-rw-r--r--synapse/crypto/event_signing.py13
-rw-r--r--synapse/event_auth.py2
-rw-r--r--synapse/events/snapshot.py4
-rw-r--r--synapse/federation/federation_base.py27
-rw-r--r--synapse/federation/federation_client.py60
-rw-r--r--synapse/federation/federation_server.py46
-rw-r--r--synapse/federation/transaction_queue.py18
-rw-r--r--synapse/federation/transport/client.py3
-rw-r--r--synapse/federation/transport/server.py9
-rw-r--r--synapse/handlers/appservice.py7
-rw-r--r--synapse/handlers/auth.py236
-rw-r--r--synapse/handlers/deactivate_account.py52
-rw-r--r--synapse/handlers/device.py24
-rw-r--r--synapse/handlers/devicemessage.py14
-rw-r--r--synapse/handlers/directory.py7
-rw-r--r--synapse/handlers/e2e_keys.py21
-rw-r--r--synapse/handlers/federation.py58
-rw-r--r--synapse/handlers/groups_local.py7
-rw-r--r--synapse/handlers/message.py321
-rw-r--r--synapse/handlers/profile.py14
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/room.py21
-rw-r--r--synapse/handlers/room_list.py3
-rw-r--r--synapse/handlers/room_member.py25
-rw-r--r--synapse/handlers/set_password.py56
-rw-r--r--synapse/handlers/user_directory.py72
-rw-r--r--synapse/http/client.py14
-rw-r--r--synapse/http/endpoint.py9
-rw-r--r--synapse/http/matrixfederationclient.py28
-rw-r--r--synapse/http/server.py127
-rw-r--r--synapse/http/servlet.py18
-rw-r--r--synapse/http/site.py10
-rw-r--r--synapse/metrics/__init__.py7
-rw-r--r--synapse/metrics/metric.py112
-rw-r--r--synapse/module_api/__init__.py14
-rw-r--r--synapse/notifier.py17
-rw-r--r--synapse/push/httppusher.py37
-rw-r--r--synapse/push/pusherpool.py24
-rw-r--r--synapse/replication/slave/storage/events.py4
-rw-r--r--synapse/replication/tcp/protocol.py19
-rw-r--r--synapse/replication/tcp/resource.py6
-rw-r--r--synapse/rest/client/v1/admin.py46
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/rest/client/v1/logout.py27
-rw-r--r--synapse/rest/client/v1/register.py34
-rw-r--r--synapse/rest/client/v1/room.py63
-rw-r--r--synapse/rest/client/v2_alpha/_base.py41
-rw-r--r--synapse/rest/client/v2_alpha/account.py146
-rw-r--r--synapse/rest/client/v2_alpha/devices.py36
-rw-r--r--synapse/rest/client/v2_alpha/groups.py22
-rw-r--r--synapse/rest/client/v2_alpha/register.py76
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py8
-rw-r--r--synapse/rest/media/v1/_base.py144
-rw-r--r--synapse/rest/media/v1/download_resource.py75
-rw-r--r--synapse/rest/media/v1/media_repository.py592
-rw-r--r--synapse/rest/media/v1/media_storage.py274
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py93
-rw-r--r--synapse/rest/media/v1/storage_provider.py140
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py220
-rw-r--r--synapse/server.py84
-rw-r--r--synapse/server.pyi21
-rw-r--r--synapse/state.py308
-rw-r--r--synapse/storage/__init__.py1
-rw-r--r--synapse/storage/_base.py80
-rw-r--r--synapse/storage/account_data.py85
-rw-r--r--synapse/storage/background_updates.py35
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite3.py19
-rw-r--r--synapse/storage/events.py209
-rw-r--r--synapse/storage/media_repository.py40
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/profile.py27
-rw-r--r--synapse/storage/registration.py10
-rw-r--r--synapse/storage/room.py196
-rw-r--r--synapse/storage/schema/delta/44/expire_url_cache.sql5
-rw-r--r--synapse/storage/schema/delta/46/local_media_repository_url_idx.sql24
-rw-r--r--synapse/storage/schema/delta/46/user_dir_null_room_ids.sql35
-rw-r--r--synapse/storage/schema/delta/47/last_access_media.sql16
-rw-r--r--synapse/storage/schema/delta/47/state_group_seq.py37
-rw-r--r--synapse/storage/search.py110
-rw-r--r--synapse/storage/state.py196
-rw-r--r--synapse/storage/user_directory.py57
-rw-r--r--synapse/util/caches/descriptors.py4
-rw-r--r--synapse/util/caches/expiringcache.py6
-rw-r--r--synapse/util/caches/lrucache.py28
-rw-r--r--synapse/util/file_consumer.py139
-rw-r--r--synapse/util/logcontext.py48
-rw-r--r--synapse/util/metrics.py75
-rw-r--r--synapse/util/retryutils.py12
-rw-r--r--synapse/util/threepids.py48
114 files changed, 4491 insertions, 2074 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 8c3d7a210a..ef8853bd24 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.25.1"
+__version__ = "0.26.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 72858cca1f..ac0a3655a5 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -270,7 +270,11 @@ class Auth(object):
             rights (str): The operation being performed; the access token must
                 allow this.
         Returns:
-            dict : dict that includes the user and the ID of their access token.
+            Deferred[dict]: dict that includes:
+               `user` (UserID)
+               `is_guest` (bool)
+               `token_id` (int|None): access token id. May be None if guest
+               `device_id` (str|None): device corresponding to access token
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index d0dfa959dc..aa15f73f36 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -46,6 +46,7 @@ class Codes(object):
     THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
     THREEPID_IN_USE = "M_THREEPID_IN_USE"
     THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+    THREEPID_DENIED = "M_THREEPID_DENIED"
     INVALID_USERNAME = "M_INVALID_USERNAME"
     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
 
@@ -140,6 +141,48 @@ class RegistrationError(SynapseError):
     pass
 
 
+class FederationDeniedError(SynapseError):
+    """An error raised when the server tries to federate with a server which
+    is not on its federation whitelist.
+
+    Attributes:
+        destination (str): The destination which has been denied
+    """
+
+    def __init__(self, destination):
+        """Raised by federation client or server to indicate that we are
+        deliberately not attempting to contact a given server because it is
+        not on our federation whitelist.
+
+        Args:
+            destination (str): the domain in question
+        """
+
+        self.destination = destination
+
+        super(FederationDeniedError, self).__init__(
+            code=403,
+            msg="Federation denied with %s." % (self.destination,),
+            errcode=Codes.FORBIDDEN,
+        )
+
+
+class InteractiveAuthIncompleteError(Exception):
+    """An error raised when UI auth is not yet complete
+
+    (This indicates we should return a 401 with 'result' as the body)
+
+    Attributes:
+        result (dict): the server response to the request, which should be
+            passed back to the client
+    """
+    def __init__(self, result):
+        super(InteractiveAuthIncompleteError, self).__init__(
+            "Interactive auth not yet complete",
+        )
+        self.result = result
+
+
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9477737759..e4318cdfc3 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -25,7 +25,9 @@ except Exception:
 from daemonize import Daemonize
 from synapse.util import PreserveLoggingContext
 from synapse.util.rlimit import change_resource_limit
-from twisted.internet import reactor
+from twisted.internet import error, reactor
+
+logger = logging.getLogger(__name__)
 
 
 def start_worker_reactor(appname, config):
@@ -120,3 +122,57 @@ def quit_with_error(error_string):
         sys.stderr.write(" %s\n" % (line.rstrip(),))
     sys.stderr.write("*" * line_length + '\n')
     sys.exit(1)
+
+
+def listen_tcp(bind_addresses, port, factory, backlog=50):
+    """
+    Create a TCP socket for a port and several addresses
+    """
+    for address in bind_addresses:
+        try:
+            reactor.listenTCP(
+                port,
+                factory,
+                backlog,
+                address
+            )
+        except error.CannotListenError as e:
+            check_bind_error(e, address, bind_addresses)
+
+
+def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+    """
+    Create an SSL socket for a port and several addresses
+    """
+    for address in bind_addresses:
+        try:
+            reactor.listenSSL(
+                port,
+                factory,
+                context_factory,
+                backlog,
+                address
+            )
+        except error.CannotListenError as e:
+            check_bind_error(e, address, bind_addresses)
+
+
+def check_bind_error(e, address, bind_addresses):
+    """
+    Handle an exception that occurred while binding to one of the addresses.
+    If binding to 0.0.0.0 failed and :: is also in the bind addresses, a
+    warning is logged and the error is ignored; otherwise it is re-raised.
+
+    Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
+    because :: binds on both IPv4 and IPv6 (as per RFC 3493).
+    When binding on 0.0.0.0 after :: this can safely be ignored.
+
+    Args:
+        e (Exception): Exception that was caught.
+        address (str): Address on which binding was attempted.
+        bind_addresses (list): Addresses on which the service listens.
+    """
+    if address == '0.0.0.0' and '::' in bind_addresses:
+        logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+    else:
+        raise e
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index ba2657bbad..c6fe4516d1 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -49,19 +49,6 @@ class AppserviceSlaveStore(
 
 
 class AppserviceServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@@ -79,17 +66,16 @@ class AppserviceServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse appservice now listening on port %d", port)
 
@@ -98,18 +84,15 @@ class AppserviceServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 129cfa901f..3b3352798d 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
 
 
 class ClientReaderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@@ -103,17 +90,16 @@ class ClientReaderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse client reader now listening on port %d", port)
 
@@ -122,18 +108,16 @@ class ClientReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
+
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 40cebe6f4a..4de43c41f0 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
 
 
 class FederationReaderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@@ -92,17 +79,16 @@ class FederationReaderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse federation reader now listening on port %d", port)
 
@@ -111,18 +97,15 @@ class FederationReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 389e3909d1..f760826d27 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
 
 
 class FederationSenderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@@ -106,17 +93,16 @@ class FederationSenderServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse federation_sender now listening on port %d", port)
 
@@ -125,18 +111,15 @@ class FederationSenderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index abc7ef5725..e32ee8fe93 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
 
 
 class FrontendProxyServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
@@ -157,17 +144,16 @@ class FrontendProxyServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse client reader now listening on port %d", port)
 
@@ -176,18 +162,15 @@ class FrontendProxyServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 9e26146338..cb82a415a6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -25,7 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
     LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
     STATIC_PREFIX, WEB_CLIENT_PREFIX
 from synapse.app import _base
-from synapse.app._base import quit_with_error
+from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
@@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource
 from synapse.rest.key.v1.server_key_resource import LocalKey
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.server import HomeServer
 from synapse.storage import are_all_users_on_domain
 from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
@@ -131,30 +130,29 @@ class SynapseHomeServer(HomeServer):
         root_resource = create_resource_tree(resources, root_resource)
 
         if tls:
-            for address in bind_addresses:
-                reactor.listenSSL(
-                    port,
-                    SynapseSite(
-                        "synapse.access.https.%s" % (site_tag,),
-                        site_tag,
-                        listener_config,
-                        root_resource,
-                    ),
-                    self.tls_server_context_factory,
-                    interface=address
-                )
+            listen_ssl(
+                bind_addresses,
+                port,
+                SynapseSite(
+                    "synapse.access.https.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                self.tls_server_context_factory,
+            )
+
         else:
-            for address in bind_addresses:
-                reactor.listenTCP(
-                    port,
-                    SynapseSite(
-                        "synapse.access.http.%s" % (site_tag,),
-                        site_tag,
-                        listener_config,
-                        root_resource,
-                    ),
-                    interface=address
+            listen_tcp(
+                bind_addresses,
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
                 )
+            )
         logger.info("Synapse now listening on port %d", port)
 
     def _configure_named_resource(self, name, compress=False):
@@ -195,14 +193,19 @@ class SynapseHomeServer(HomeServer):
             })
 
         if name in ["media", "federation", "client"]:
-            media_repo = MediaRepositoryResource(self)
-            resources.update({
-                MEDIA_PREFIX: media_repo,
-                LEGACY_MEDIA_PREFIX: media_repo,
-                CONTENT_REPO_PREFIX: ContentRepoResource(
-                    self, self.config.uploads_path
-                ),
-            })
+            if self.get_config().enable_media_repo:
+                media_repo = self.get_media_repository_resource()
+                resources.update({
+                    MEDIA_PREFIX: media_repo,
+                    LEGACY_MEDIA_PREFIX: media_repo,
+                    CONTENT_REPO_PREFIX: ContentRepoResource(
+                        self, self.config.uploads_path
+                    ),
+                })
+            elif name == "media":
+                raise ConfigError(
+                    "'media' resource conflicts with enable_media_repo=False",
+                )
 
         if name in ["keys", "federation"]:
             resources.update({
@@ -225,18 +228,15 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             elif listener["type"] == "replication":
                 bind_addresses = listener["bind_addresses"]
                 for address in bind_addresses:
@@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
         except IncorrectDatabaseSetup as e:
             quit_with_error(e.message)
 
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
 
 def setup(config_options):
     """
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 36c18bdbcb..1ed1ca8772 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
 from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
 from synapse.storage.media_repository import MediaRepositoryStore
@@ -61,19 +60,6 @@ class MediaRepositorySlavedStore(
 
 
 class MediaRepositoryServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@@ -89,7 +75,7 @@ class MediaRepositoryServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
                 elif name == "media":
-                    media_repo = MediaRepositoryResource(self)
+                    media_repo = self.get_media_repository_resource()
                     resources.update({
                         MEDIA_PREFIX: media_repo,
                         LEGACY_MEDIA_PREFIX: media_repo,
@@ -100,17 +86,16 @@ class MediaRepositoryServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse media repository now listening on port %d", port)
 
@@ -119,18 +104,15 @@ class MediaRepositoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
@@ -151,6 +133,13 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.media_repository"
 
+    if config.enable_media_repo:
+        _base.quit_with_error(
+            "enable_media_repo must be disabled in the main synapse process\n"
+            "before the media repo can be run in a separate worker.\n"
+            "Please add ``enable_media_repo: false`` to the main config\n"
+        )
+
     setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index db9a4d16f4..32ccea3f13 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -81,19 +81,6 @@ class PusherSlaveStore(
 
 
 class PusherServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = PusherSlaveStore(self.get_db_conn(), self)
@@ -114,17 +101,16 @@ class PusherServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse pusher now listening on port %d", port)
 
@@ -133,18 +119,15 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 576ac6fb7e..f87531f1b6 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
 
 
 class SynchrotronServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@@ -288,17 +275,16 @@ class SynchrotronServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse synchrotron now listening on port %d", port)
 
@@ -307,18 +293,15 @@ class SynchrotronServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
@@ -340,11 +323,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
 
         self.store = hs.get_datastore()
         self.typing_handler = hs.get_typing_handler()
+        # NB this is a SynchrotronPresence, not a normal PresenceHandler
         self.presence_handler = hs.get_presence_handler()
         self.notifier = hs.get_notifier()
 
-        self.presence_handler.sync_callback = self.send_user_sync
-
     def on_rdata(self, stream_name, token, rows):
         super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
 
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 3bd7ef7bba..0f0ddfa78a 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -184,6 +184,9 @@ def main():
         worker_configfiles.append(worker_configfile)
 
     if options.all_processes:
+        # To start the main synapse with -a you need to add a worker file
+        # with worker_app == "synapse.app.homeserver"
+        start_stop_synapse = False
         worker_configdir = options.all_processes
         if not os.path.isdir(worker_configdir):
             write(
@@ -200,11 +203,29 @@ def main():
         with open(worker_configfile) as stream:
             worker_config = yaml.load(stream)
         worker_app = worker_config["worker_app"]
-        worker_pidfile = worker_config["worker_pid_file"]
-        worker_daemonize = worker_config["worker_daemonize"]
-        assert worker_daemonize, "In config %r: expected '%s' to be True" % (
-            worker_configfile, "worker_daemonize")
-        worker_cache_factor = worker_config.get("synctl_cache_factor")
+        if worker_app == "synapse.app.homeserver":
+            # We need to special case all of this to pick up options that may
+            # be set in the main config file or in this worker config file.
+            worker_pidfile = (
+                worker_config.get("pid_file")
+                or pidfile
+            )
+            worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
+            daemonize = worker_config.get("daemonize") or config.get("daemonize")
+            assert daemonize, "Main process must have daemonize set to true"
+
+            # The master process doesn't support using worker_* config.
+            for key in worker_config:
+                if key == "worker_app":  # But we allow worker_app
+                    continue
+                assert not key.startswith("worker_"), \
+                    "Main process cannot use worker_* config"
+        else:
+            worker_pidfile = worker_config["worker_pid_file"]
+            worker_daemonize = worker_config["worker_daemonize"]
+            assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+                worker_configfile, "worker_daemonize")
+            worker_cache_factor = worker_config.get("synctl_cache_factor")
         workers.append(Worker(
             worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
         ))
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index be661a70c7..494ccb702c 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
 
 
 class UserDirectoryServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
@@ -131,17 +118,16 @@ class UserDirectoryServer(HomeServer):
 
         root_resource = create_resource_tree(resources, Resource())
 
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse user_dir now listening on port %d", port)
 
@@ -150,18 +136,15 @@ class UserDirectoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 05e242aef6..bf19cfee29 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -36,6 +36,7 @@ from .workers import WorkerConfig
 from .push import PushConfig
 from .spam_checker import SpamCheckerConfig
 from .groups import GroupsConfig
+from .user_directory import UserDirectoryConfig
 
 
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
                        JWTConfig, PasswordConfig, EmailConfig,
                        WorkerConfig, PasswordAuthProviderConfig, PushConfig,
-                       SpamCheckerConfig, GroupsConfig,):
+                       SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
     pass
 
 
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index a1d6e4d4f7..3f70039acd 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
 version: 1
 
 formatters:
-  precise:
-   format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
-- %(message)s'
+    precise:
+        format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
+%(request)s - %(message)s'
 
 filters:
-  context:
-    (): synapse.util.logcontext.LoggingContextFilter
-    request: ""
+    context:
+        (): synapse.util.logcontext.LoggingContextFilter
+        request: ""
 
 handlers:
-  file:
-    class: logging.handlers.RotatingFileHandler
-    formatter: precise
-    filename: ${log_file}
-    maxBytes: 104857600
-    backupCount: 10
-    filters: [context]
-  console:
-    class: logging.StreamHandler
-    formatter: precise
-    filters: [context]
+    file:
+        class: logging.handlers.RotatingFileHandler
+        formatter: precise
+        filename: ${log_file}
+        maxBytes: 104857600
+        backupCount: 10
+        filters: [context]
+    console:
+        class: logging.StreamHandler
+        formatter: precise
+        filters: [context]
 
 loggers:
     synapse:
@@ -74,17 +74,10 @@ class LoggingConfig(Config):
         self.log_file = self.abspath(config.get("log_file"))
 
     def default_config(self, config_dir_path, server_name, **kwargs):
-        log_file = self.abspath("homeserver.log")
         log_config = self.abspath(
             os.path.join(config_dir_path, server_name + ".log.config")
         )
         return """
-        # Logging verbosity level. Ignored if log_config is specified.
-        verbose: 0
-
-        # File to write logging to. Ignored if log_config is specified.
-        log_file: "%(log_file)s"
-
         # A yaml python logging config file
         log_config: "%(log_config)s"
         """ % locals()
@@ -123,9 +116,10 @@ class LoggingConfig(Config):
     def generate_files(self, config):
         log_config = config.get("log_config")
         if log_config and not os.path.exists(log_config):
+            log_file = self.abspath("homeserver.log")
             with open(log_config, "wb") as log_config_file:
                 log_config_file.write(
-                    DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+                    DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
                 )
 
 
@@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
     )
 
     if log_config is None:
+        # We don't have a log config, so fall back to the 'verbosity' param
+        # from the config or cmdline. (Note that we generate a log config for
+        # new installs, so this will be an unusual case.)
         level = logging.INFO
         level_for_storage = logging.INFO
         if config.verbosity:
@@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
             if config.verbosity > 1:
                 level_for_storage = logging.DEBUG
 
-        # FIXME: we need a logging.WARN for a -q quiet option
         logger = logging.getLogger('')
         logger.setLevel(level)
 
-        logging.getLogger('synapse.storage').setLevel(level_for_storage)
+        logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
 
         formatter = logging.Formatter(log_format)
         if log_file:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index e9828fac17..6602c5b4c7 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -29,10 +29,10 @@ class PasswordAuthProviderConfig(Config):
         # param.
         ldap_config = config.get("ldap_config", {})
         if ldap_config.get("enabled", False):
-            providers.append[{
+            providers.append({
                 'module': LDAP_PROVIDER,
                 'config': ldap_config,
-            }]
+            })
 
         providers.extend(config.get("password_providers", []))
         for provider in providers:
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ef917fc9f2..336959094b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
                 strtobool(str(config["disable_registration"]))
             )
 
+        self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+        self.allowed_local_3pids = config.get("allowed_local_3pids", [])
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -52,6 +54,23 @@ class RegistrationConfig(Config):
         # Enable registration for new users.
         enable_registration: False
 
+        # The user must provide all of the below types of 3PID when registering.
+        #
+        # registrations_require_3pid:
+        #     - email
+        #     - msisdn
+
+        # Mandate that users are only allowed to associate certain formats of
+        # 3PIDs with accounts on this server.
+        #
+        # allowed_local_3pids:
+        #     - medium: email
+        #       pattern: ".*@matrix\\.org"
+        #     - medium: email
+        #       pattern: ".*@vector\\.im"
+        #     - medium: msisdn
+        #       pattern: "\\+44"
+
         # If set, allows registration by anyone who also has the shared
         # secret, even if registration is otherwise disabled.
         registration_shared_secret: "%(registration_shared_secret)s"
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 6baa474931..25ea77738a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -16,6 +16,8 @@
 from ._base import Config, ConfigError
 from collections import namedtuple
 
+from synapse.util.module_loader import load_module
+
 
 MISSING_NETADDR = (
     "Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
     "ThumbnailRequirement", ["width", "height", "method", "media_type"]
 )
 
+MediaStorageProviderConfig = namedtuple(
+    "MediaStorageProviderConfig", (
+        "store_local",  # Whether to store newly uploaded local files
+        "store_remote",  # Whether to store newly downloaded remote files
+        "store_synchronous",  # Whether to wait for successful storage for local uploads
+    ),
+)
+
 
 def parse_thumbnail_requirements(thumbnail_sizes):
     """ Takes a list of dictionaries with "width", "height", and "method" keys
@@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
 
         self.media_store_path = self.ensure_directory(config["media_store_path"])
 
-        self.backup_media_store_path = config.get("backup_media_store_path")
-        if self.backup_media_store_path:
-            self.backup_media_store_path = self.ensure_directory(
-                self.backup_media_store_path
-            )
+        backup_media_store_path = config.get("backup_media_store_path")
 
-        self.synchronous_backup_media_store = config.get(
+        synchronous_backup_media_store = config.get(
             "synchronous_backup_media_store", False
         )
 
+        storage_providers = config.get("media_storage_providers", [])
+
+        if backup_media_store_path:
+            if storage_providers:
+                raise ConfigError(
+                    "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+                )
+
+            storage_providers = [{
+                "module": "file_system",
+                "store_local": True,
+                "store_synchronous": synchronous_backup_media_store,
+                "store_remote": True,
+                "config": {
+                    "directory": backup_media_store_path,
+                }
+            }]
+
+        # This is a list of configs that can be used to create the storage
+        # providers. The entries are tuples of (Class, class_config,
+        # MediaStorageProviderConfig), where Class is the class of the provider,
+        # class_config is the config to pass to it, and
+        # MediaStorageProviderConfig holds options for StorageProviderWrapper.
+        #
+        # We don't create the storage providers here as not all workers need
+        # them to be started.
+        self.media_storage_providers = []
+
+        for provider_config in storage_providers:
+            # We special case the module "file_system" so as not to need to
+            # expose FileStorageProviderBackend
+            if provider_config["module"] == "file_system":
+                provider_config["module"] = (
+                    "synapse.rest.media.v1.storage_provider"
+                    ".FileStorageProviderBackend"
+                )
+
+            provider_class, parsed_config = load_module(provider_config)
+
+            wrapper_config = MediaStorageProviderConfig(
+                provider_config.get("store_local", False),
+                provider_config.get("store_remote", False),
+                provider_config.get("store_synchronous", False),
+            )
+
+            self.media_storage_providers.append(
+                (provider_class, parsed_config, wrapper_config,)
+            )
+
         self.uploads_path = self.ensure_directory(config["uploads_path"])
         self.dynamic_thumbnails = config["dynamic_thumbnails"]
         self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
         # Directory where uploaded images and attachments are stored.
         media_store_path: "%(media_store)s"
 
-        # A secondary directory where uploaded images and attachments are
-        # stored as a backup.
-        # backup_media_store_path: "%(media_store)s"
-
-        # Whether to wait for successful write to backup media store before
-        # returning successfully.
-        # synchronous_backup_media_store: false
+        # Media storage providers allow media to be stored in different
+        # locations.
+        # media_storage_providers:
+        # - module: file_system
+        #   # Whether to write new local files.
+        #   store_local: false
+        #   # Whether to write new remote media
+        #   store_remote: false
+        #   # Whether to block upload requests waiting for write to this
+        #   # provider to complete
+        #   store_synchronous: false
+        #   config:
+        #     directory: /mnt/some/other/directory
 
         # Directory where in-progress uploads are stored.
         uploads_path: "%(uploads_path)s"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 4d9193536d..8f0b6d1f28 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -41,6 +41,12 @@ class ServerConfig(Config):
         # false only if we are updating the user directory in a worker
         self.update_user_directory = config.get("update_user_directory", True)
 
+        # Whether to enable the media repository endpoints. This should be set
+        # to false if the media repository is running in a separate worker;
+        # doing so ensures that we do not run cache cleanup jobs on the
+        # master, which could otherwise cause inconsistency.
+        self.enable_media_repo = config.get("enable_media_repo", True)
+
         self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
 
         # Whether we should block invites sent to users on this server
@@ -49,6 +55,17 @@ class ServerConfig(Config):
             "block_non_admin_invites", False,
         )
 
+        # FIXME: federation_domain_whitelist needs sytests
+        self.federation_domain_whitelist = None
+        federation_domain_whitelist = config.get(
+            "federation_domain_whitelist", None
+        )
+        # turn the whitelist into a hash for speed of lookup
+        if federation_domain_whitelist is not None:
+            self.federation_domain_whitelist = {}
+            for domain in federation_domain_whitelist:
+                self.federation_domain_whitelist[domain] = True
+
         if self.public_baseurl is not None:
             if self.public_baseurl[-1] != '/':
                 self.public_baseurl += '/'
@@ -204,6 +221,17 @@ class ServerConfig(Config):
         # (except those sent by local server admins). The default is False.
         # block_non_admin_invites: True
 
+        # Restrict federation to the following whitelist of domains.
+        # N.B. we recommend also firewalling your federation listener to limit
+        # inbound federation traffic as early as possible, rather than relying
+        # purely on this application-layer restriction.  If not specified, the
+        # default is to whitelist everything.
+        #
+        # federation_domain_whitelist:
+        #  - lon.example.com
+        #  - nyc.example.com
+        #  - syd.example.com
+
         # List of ports that Synapse should listen on, their purpose and their
         # configuration.
         listeners:
@@ -214,13 +242,12 @@ class ServerConfig(Config):
             port: %(bind_port)s
 
             # Local addresses to listen on.
-            # This will listen on all IPv4 addresses by default.
+            # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
+            # addresses by default. For most other OSes, this will only listen
+            # on IPv6.
             bind_addresses:
+              - '::'
               - '0.0.0.0'
-              # Uncomment to listen on all IPv6 interfaces
-              # N.B: On at least Linux this will also listen on all IPv4
-              # addresses, so you will need to comment out the line above.
-              # - '::'
 
             # This is a 'http' listener, allows us to specify 'resources'.
             type: http
@@ -258,7 +285,7 @@ class ServerConfig(Config):
           # For when matrix traffic passes through loadbalancer that unwraps TLS.
           - port: %(unsecure_port)s
             tls: false
-            bind_addresses: ['0.0.0.0']
+            bind_addresses: ['::', '0.0.0.0']
             type: http
 
             x_forwarded: false
@@ -272,7 +299,7 @@ class ServerConfig(Config):
           # Turn on the twisted ssh manhole service on localhost on the given
           # port.
           # - port: 9000
-          #   bind_address: 127.0.0.1
+          #   bind_addresses: ['::1', '127.0.0.1']
           #   type: manhole
         """ % locals()
 
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 4748f71c2f..29eb012ddb 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
         # certificates returned by this server match one of the fingerprints.
         #
         # Synapse automatically adds the fingerprint of its own certificate
-        # to the list. So if federation traffic is handle directly by synapse
+        # to the list. So if federation traffic is handled directly by synapse
         # then no modification to the list is required.
         #
         # If synapse is run behind a load balancer that handles the TLS then it
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
new file mode 100644
index 0000000000..38e8947843
--- /dev/null
+++ b/synapse/config/user_directory.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class UserDirectoryConfig(Config):
+    """User Directory Configuration
+    Configuration for the behaviour of the /user_directory API
+    """
+
+    def read_config(self, config):
+        self.user_directory_search_all_users = False
+        user_directory_config = config.get("user_directory", None)
+        if user_directory_config:
+            self.user_directory_search_all_users = (
+                user_directory_config.get("search_all_users", False)
+            )
+
+    def default_config(self, config_dir_path, server_name, **kwargs):
+        return """
+        # User Directory configuration
+        #
+        # 'search_all_users' defines whether to search all users visible to your HS
+        # when searching the user directory, rather than limiting to users visible
+        # in public rooms.  Defaults to false.  If you set it to true, you'll have to run
+        # UPDATE user_directory_stream_pos SET stream_id = NULL;
+        # on your database to tell it to rebuild the user_directory search indexes.
+        #
+        #user_directory:
+        #   search_all_users: false
+        """
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c5a5a8919c..4b6884918d 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -23,6 +23,11 @@ class WorkerConfig(Config):
 
     def read_config(self, config):
         self.worker_app = config.get("worker_app")
+
+        # Canonicalise worker_app so that master always has None
+        if self.worker_app == "synapse.app.homeserver":
+            self.worker_app = None
+
         self.worker_listeners = config.get("worker_listeners")
         self.worker_daemonize = config.get("worker_daemonize")
         self.worker_pid_file = config.get("worker_pid_file")
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 0d0e7b5286..aaa3efaca3 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     """Check whether the hash for this PDU matches the contents"""
     name, expected_hash = compute_content_hash(event, hash_algorithm)
     logger.debug("Expecting hash: %s", encode_base64(expected_hash))
-    if name not in event.hashes:
+
+    # Some malformed events lack a 'hashes' field. Protect against it being
+    # missing or of the wrong type by rejecting it just like an unhashed event.
+    hashes = event.get("hashes")
+    if not isinstance(hashes, dict):
+        raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+
+    if name not in hashes:
         raise SynapseError(
             400,
             "Algorithm %s not in hashes %s" % (
-                name, list(event.hashes),
+                name, list(hashes),
             ),
             Codes.UNAUTHORIZED,
         )
-    message_hash_base64 = event.hashes[name]
+    message_hash_base64 = hashes[name]
     try:
         message_hash_bytes = decode_base64(message_hash_base64)
     except Exception:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 061ee86b16..cd5627e36a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
         # TODO (erikj): Implement kicks.
         if target_banned and user_level < ban_level:
             raise AuthError(
-                403, "You cannot unban user &s." % (target_user_id,)
+                403, "You cannot unban user %s." % (target_user_id,)
             )
         elif target_user_id != event.user_id:
             kick_level = _get_named_level(auth_events, "kick", 50)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..87e3fe7b97 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -25,7 +25,9 @@ class EventContext(object):
             The current state map excluding the current event.
             (type, state_key) -> event_id
 
-        state_group (int): state group id
+        state_group (int|None): state group id, if the state has been stored
+            as a state group. This is usually only None if e.g. the event is
+            an outlier.
         rejected (bool|str): A rejection reason if the event was rejected, else
             False
 
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a0f5d40eb3..7918d3e442 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,7 +16,9 @@ import logging
 
 from synapse.api.errors import SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
 from synapse.util import unwrapFirstError, logcontext
 from twisted.internet import defer
 
@@ -169,3 +171,28 @@ class FederationBase(object):
             )
 
         return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+    """Construct a FrozenEvent from an event json received over federation
+
+    Args:
+        pdu_json (object): pdu as received over federation
+        outlier (bool): True to mark this event as an outlier
+
+    Returns:
+        FrozenEvent
+
+    Raises:
+        SynapseError: if the pdu is missing required fields
+    """
+    # we could probably enforce a bunch of other fields here (room_id, sender,
+    # origin, etc etc)
+    assert_params_in_request(pdu_json, ('event_id', 'type'))
+    event = FrozenEvent(
+        pdu_json
+    )
+
+    event.internal_metadata.outlier = outlier
+
+    return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b8f02f5391..813907f7f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,29 +14,29 @@
 # limitations under the License.
 
 
+import copy
+import itertools
+import logging
+import random
+
 from twisted.internet import defer
 
-from .federation_base import FederationBase
 from synapse.api.constants import Membership
-
 from synapse.api.errors import (
-    CodeMessageException, HttpResponseException, SynapseError,
+    CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
 )
-from synapse.util import unwrapFirstError, logcontext
+from synapse.events import builder
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
+from synapse.util.logutils import log_function
 from synapse.util.retryutils import NotRetryingDestination
 
-import copy
-import itertools
-import logging
-import random
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -184,7 +184,7 @@ class FederationClient(FederationBase):
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=False)
+            event_from_pdu_json(p, outlier=False)
             for p in transaction_data["pdus"]
         ]
 
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
                 logger.debug("transaction_data %r", transaction_data)
 
                 pdu_list = [
-                    self.event_from_pdu_json(p, outlier=outlier)
+                    event_from_pdu_json(p, outlier=outlier)
                     for p in transaction_data["pdus"]
                 ]
 
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
+            except FederationDeniedError as e:
+                logger.info(e.message)
+                continue
             except Exception as e:
                 pdu_attempts[destination] = now
 
@@ -336,11 +339,11 @@ class FederationClient(FederationBase):
         )
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
         ]
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in result.get("auth_chain", [])
         ]
 
@@ -441,7 +444,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in res["auth_chain"]
         ]
 
@@ -570,12 +573,12 @@ class FederationClient(FederationBase):
                 logger.debug("Got content: %s", content)
 
                 state = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("state", [])
                 ]
 
                 auth_chain = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("auth_chain", [])
                 ]
 
@@ -650,7 +653,7 @@ class FederationClient(FederationBase):
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        pdu = self.event_from_pdu_json(pdu_dict)
+        pdu = event_from_pdu_json(pdu_dict)
 
         # Check signatures are correct.
         pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +743,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(e)
+            event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
@@ -788,7 +791,7 @@ class FederationClient(FederationBase):
             )
 
             events = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content.get("events", [])
             ]
 
@@ -805,15 +808,6 @@ class FederationClient(FederationBase):
 
         defer.returnValue(signed_events)
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def forward_third_party_invite(self, destinations, room_id, event_dict):
         for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a2327f24b6..9849953c9b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,25 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
+import logging
 
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+import simplejson as json
+from twisted.internet import defer
 
+from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
+from synapse.federation.units import Edu, Transaction
+import synapse.metrics
+from synapse.types import get_domain_from_id
 from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
-import synapse.metrics
-
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
 
 # when processing incoming transactions, we try to handle multiple rooms in
 # parallel, up to this limit.
@@ -172,7 +171,7 @@ class FederationServer(FederationBase):
                 p["age_ts"] = request_time - int(p["age"])
                 del p["age"]
 
-            event = self.event_from_pdu_json(p)
+            event = event_from_pdu_json(p)
             room_id = event.room_id
             pdus_by_room.setdefault(room_id, []).append(event)
 
@@ -346,7 +345,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -354,7 +353,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -374,7 +373,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -411,7 +410,7 @@ class FederationServer(FederationBase):
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
             auth_chain = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content["auth_chain"]
             ]
 
@@ -586,15 +585,6 @@ class FederationServer(FederationBase):
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def exchange_third_party_invite(
             self,
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 7a3c9cbb70..a141ec9953 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,8 +19,8 @@ from twisted.internet import defer
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
-from synapse.api.errors import HttpResponseException
-from synapse.util import logcontext
+from synapse.api.errors import HttpResponseException, FederationDeniedError
+from synapse.util import logcontext, PreserveLoggingContext
 from synapse.util.async import run_on_reactor
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
 
 sent_transactions_counter = client_metrics.register_counter("sent_transactions")
 
+events_processed_counter = client_metrics.register_counter("events_processed")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -146,7 +148,6 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    @defer.inlineCallbacks
     def notify_new_events(self, current_id):
         """This gets called when we have some new events we might want to
         send out to other servers.
@@ -156,6 +157,13 @@ class TransactionQueue(object):
         if self._is_processing:
             return
 
+        # fire off a processing loop in the background. It's likely it will
+        # outlast the current request, so run it in the sentinel logcontext.
+        with PreserveLoggingContext():
+            self._process_event_queue_loop()
+
+    @defer.inlineCallbacks
+    def _process_event_queue_loop(self):
         try:
             self._is_processing = True
             while True:
@@ -199,6 +207,8 @@ class TransactionQueue(object):
 
                     self._send_pdu(event, destinations)
 
+                events_processed_counter.inc_by(len(events))
+
                 yield self.store.update_federation_out_pos(
                     "events", next_token
                 )
@@ -480,6 +490,8 @@ class TransactionQueue(object):
                     (e.retry_last_ts + e.retry_interval) / 1000.0
                 ),
             )
+        except FederationDeniedError as e:
+            logger.info(e)
         except Exception as e:
             logger.warn(
                 "TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1f3ce238f6..5488e82985 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -212,6 +212,9 @@ class TransportLayerClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if the remote destination
+            is not in our federation whitelist.
         """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2b02b021ec..06c16ba4fa 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -81,6 +81,7 @@ class Authenticator(object):
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
@@ -92,6 +93,12 @@ class Authenticator(object):
             "signatures": {},
         }
 
+        if (
+            self.federation_domain_whitelist is not None and
+            self.server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(self.server_name)
+
         if content is not None:
             json_request["content"] = content
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index feca3e4c10..3dd3fa2a27 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,6 +15,7 @@
 
 from twisted.internet import defer
 
+import synapse
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@@ -23,6 +24,10 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
 
 def log_failure(failure):
     logger.error(
@@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
                                 service, event
                             )
 
+                    events_processed_counter.inc_by(len(events))
+
                     yield self.store.set_appservice_last_pos(upper_bound)
             finally:
                 self.is_processing = False
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 080eb14271..258cc345dc 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,15 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, threads
 
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
+from synapse.api.errors import (
+    AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+    SynapseError,
+)
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
 from synapse.util.async import run_on_reactor
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
 
 from twisted.web.client import PartialDownloadError
 
@@ -46,7 +50,6 @@ class AuthHandler(BaseHandler):
         """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
-            LoginType.PASSWORD: self._check_password_auth,
             LoginType.RECAPTCHA: self._check_recaptcha,
             LoginType.EMAIL_IDENTITY: self._check_email_identity,
             LoginType.MSISDN: self._check_msisdn,
@@ -75,15 +78,76 @@ class AuthHandler(BaseHandler):
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
 
-        login_types = set()
+        # we keep this as a list despite the O(N^2) implication so that we can
+        # keep PASSWORD first and avoid confusing clients which pick the first
+        # type in the list. (NB that the spec doesn't require us to do so and
+        # clients which favour types that they don't understand over those that
+        # they do are technically broken)
+        login_types = []
         if self._password_enabled:
-            login_types.add(LoginType.PASSWORD)
+            login_types.append(LoginType.PASSWORD)
         for provider in self.password_providers:
             if hasattr(provider, "get_supported_login_types"):
-                login_types.update(
-                    provider.get_supported_login_types().keys()
-                )
-        self._supported_login_types = frozenset(login_types)
+                for t in provider.get_supported_login_types().keys():
+                    if t not in login_types:
+                        login_types.append(t)
+        self._supported_login_types = login_types
+
+    @defer.inlineCallbacks
+    def validate_user_via_ui_auth(self, requester, request_body, clientip):
+        """
+        Checks that the user is who they claim to be, via a UI auth.
+
+        This is used for things like device deletion and password reset where
+        the user already has a valid access token, but we want to double-check
+        that it isn't stolen by re-authenticating them.
+
+        Args:
+            requester (Requester): The user, as given by the access token
+
+            request_body (dict): The body of the request sent by the client
+
+            clientip (str): The IP address of the client.
+
+        Returns:
+            defer.Deferred[dict]: the parameters for this request (which may
+                have been given only in a previous call).
+
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                any of the permitted login flows
+
+            AuthError if the client has completed a login flow, and it gives
+                a different user to `requester`
+        """
+
+        # build a list of supported flows
+        flows = [
+            [login_type] for login_type in self._supported_login_types
+        ]
+
+        result, params, _ = yield self.check_auth(
+            flows, request_body, clientip,
+        )
+
+        # find the completed login type
+        for login_type in self._supported_login_types:
+            if login_type not in result:
+                continue
+
+            user_id = result[login_type]
+            break
+        else:
+            # this can't happen
+            raise Exception(
+                "check_auth returned True but no successful login type",
+            )
+
+        # check that the UI auth matched the access token
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Invalid auth")
+
+        defer.returnValue(params)
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
@@ -95,26 +159,36 @@ class AuthHandler(BaseHandler):
         session with a map, which maps each auth-type (str) to the relevant
         identity authenticated by that auth-type (mostly str, but for captcha, bool).
 
+        If no auth flows have been completed successfully, raises an
+        InteractiveAuthIncompleteError. To handle this, you can use
+        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+        decorator.
+
         Args:
             flows (list): A list of login flows. Each flow is an ordered list of
                           strings representing auth-types. At least one full
                           flow must be completed in order for auth to be successful.
+
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
+
             clientip (str): The IP address of the client.
+
         Returns:
-            A tuple of (authed, dict, dict, session_id) where authed is true if
-            the client has successfully completed an auth flow. If it is true
-            the first dict contains the authenticated credentials of each stage.
+            defer.Deferred[dict, dict, str]: a deferred tuple of
+                (creds, params, session_id).
+
+                'creds' contains the authenticated credentials of each stage.
 
-            If authed is false, the first dictionary is the server response to
-            the login request and should be passed back to the client.
+                'params' contains the parameters for this request (which may
+                have been given only in a previous call).
 
-            In either case, the second dict contains the parameters for this
-            request (which may have been given only in a previous call).
+                'session_id' is the ID of this session, either passed in by the
+                client or assigned by this call
 
-            session_id is the ID of this session, either passed in by the client
-            or assigned by the call to check_auth
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                all the stages in any of the permitted flows.
         """
 
         authdict = None
@@ -142,11 +216,8 @@ class AuthHandler(BaseHandler):
             clientdict = session['clientdict']
 
         if not authdict:
-            defer.returnValue(
-                (
-                    False, self._auth_dict_for_flows(flows, session),
-                    clientdict, session['id']
-                )
+            raise InteractiveAuthIncompleteError(
+                self._auth_dict_for_flows(flows, session),
             )
 
         if 'creds' not in session:
@@ -157,14 +228,12 @@ class AuthHandler(BaseHandler):
         errordict = {}
         if 'type' in authdict:
             login_type = authdict['type']
-            if login_type not in self.checkers:
-                raise LoginError(400, "", Codes.UNRECOGNIZED)
             try:
-                result = yield self.checkers[login_type](authdict, clientip)
+                result = yield self._check_auth_dict(authdict, clientip)
                 if result:
                     creds[login_type] = result
                     self._save_session(session)
-            except LoginError, e:
+            except LoginError as e:
                 if login_type == LoginType.EMAIL_IDENTITY:
                     # riot used to have a bug where it would request a new
                     # validation token (thus sending a new email) each time it
@@ -173,7 +242,7 @@ class AuthHandler(BaseHandler):
                     #
                     # Grandfather in the old behaviour for now to avoid
                     # breaking old riot deployments.
-                    raise e
+                    raise
 
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
@@ -190,12 +259,14 @@ class AuthHandler(BaseHandler):
                     "Auth completed with creds: %r. Client dict has keys: %r",
                     creds, clientdict.keys()
                 )
-                defer.returnValue((True, creds, clientdict, session['id']))
+                defer.returnValue((creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         ret.update(errordict)
-        defer.returnValue((False, ret, clientdict, session['id']))
+        raise InteractiveAuthIncompleteError(
+            ret,
+        )
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -268,17 +339,35 @@ class AuthHandler(BaseHandler):
         return sess.setdefault('serverdict', {}).get(key, default)
 
     @defer.inlineCallbacks
-    def _check_password_auth(self, authdict, _):
-        if "user" not in authdict or "password" not in authdict:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+    def _check_auth_dict(self, authdict, clientip):
+        """Attempt to validate the auth dict provided by a client
 
-        user_id = authdict["user"]
-        password = authdict["password"]
+        Args:
+            authdict (object): auth dict provided by the client
+            clientip (str): IP address of the client
+
+        Returns:
+            Deferred: result of the stage verification.
+
+        Raises:
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
+        """
+        login_type = authdict['type']
+        checker = self.checkers.get(login_type)
+        if checker is not None:
+            res = yield checker(authdict, clientip)
+            defer.returnValue(res)
 
-        (canonical_id, callback) = yield self.validate_login(user_id, {
-            "type": LoginType.PASSWORD,
-            "password": password,
-        })
+        # build a v1-login-style dict out of the authdict and fall back to the
+        # v1 code
+        user_id = authdict.get("user")
+
+        if user_id is None:
+            raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+        (canonical_id, callback) = yield self.validate_login(user_id, authdict)
         defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
@@ -626,7 +715,7 @@ class AuthHandler(BaseHandler):
         if not lookupres:
             defer.returnValue(None)
         (user_id, password_hash) = lookupres
-        result = self.validate_hash(password, password_hash)
+        result = yield self.validate_hash(password, password_hash)
         if not result:
             logger.warn("Failed password login for user %s", user_id)
             defer.returnValue(None)
@@ -650,41 +739,6 @@ class AuthHandler(BaseHandler):
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword, requester=None):
-        password_hash = self.hash(newpassword)
-
-        except_access_token_id = requester.access_token_id if requester else None
-
-        try:
-            yield self.store.user_set_password_hash(user_id, password_hash)
-        except StoreError as e:
-            if e.code == 404:
-                raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
-            raise e
-        yield self.delete_access_tokens_for_user(
-            user_id, except_token_id=except_access_token_id,
-        )
-        yield self.hs.get_pusherpool().remove_pushers_by_user(
-            user_id, except_access_token_id
-        )
-
-    @defer.inlineCallbacks
-    def deactivate_account(self, user_id):
-        """Deactivate a user's account
-
-        Args:
-            user_id (str): ID of user to be deactivated
-
-        Returns:
-            Deferred
-        """
-        # FIXME: Theoretically there is a race here wherein user resets
-        # password using threepid.
-        yield self.delete_access_tokens_for_user(user_id)
-        yield self.store.user_delete_threepids(user_id)
-        yield self.store.user_set_password_hash(user_id, None)
-
-    @defer.inlineCallbacks
     def delete_access_token(self, access_token):
         """Invalidate a single access token
 
@@ -706,6 +760,12 @@ class AuthHandler(BaseHandler):
                     access_token=access_token,
                 )
 
+        # delete pushers associated with this access token
+        if user_info["token_id"] is not None:
+            yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+                str(user_info["user"]), (user_info["token_id"], )
+            )
+
     @defer.inlineCallbacks
     def delete_access_tokens_for_user(self, user_id, except_token_id=None,
                                       device_id=None):
@@ -728,13 +788,18 @@ class AuthHandler(BaseHandler):
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
             if hasattr(provider, "on_logged_out"):
-                for token, device_id in tokens_and_devices:
+                for token, token_id, device_id in tokens_and_devices:
                     yield provider.on_logged_out(
                         user_id=user_id,
                         device_id=device_id,
                         access_token=token,
                     )
 
+        # delete pushers associated with the access tokens
+        yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+            user_id, (token_id for _, token_id, _ in tokens_and_devices),
+        )
+
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
         # 'Canonicalise' email addresses down to lower case.
@@ -778,10 +843,13 @@ class AuthHandler(BaseHandler):
             password (str): Password to hash.
 
         Returns:
-            Hashed password (str).
+            Deferred(str): Hashed password.
         """
-        return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
-                             bcrypt.gensalt(self.bcrypt_rounds))
+        def _do_hash():
+            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+                                 bcrypt.gensalt(self.bcrypt_rounds))
+
+        return make_deferred_yieldable(threads.deferToThread(_do_hash))
 
     def validate_hash(self, password, stored_hash):
         """Validates that self.hash(password) == stored_hash.
@@ -791,13 +859,17 @@ class AuthHandler(BaseHandler):
             stored_hash (str): Expected hash value.
 
         Returns:
-            Whether self.hash(password) == stored_hash (bool).
+            Deferred(bool): Whether self.hash(password) == stored_hash.
         """
-        if stored_hash:
+
+        def _do_validate_hash():
             return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
                                  stored_hash.encode('utf8')) == stored_hash
+
+        if stored_hash:
+            return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
         else:
-            return False
+            return defer.succeed(False)
 
 
 class MacaroonGeneartor(object):
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
new file mode 100644
index 0000000000..b1d3814909
--- /dev/null
+++ b/synapse/handlers/deactivate_account.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DeactivateAccountHandler(BaseHandler):
+    """Handler which deals with deactivating user accounts."""
+    def __init__(self, hs):
+        super(DeactivateAccountHandler, self).__init__(hs)
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
+
+    @defer.inlineCallbacks
+    def deactivate_account(self, user_id):
+        """Deactivate a user's account
+
+        Args:
+            user_id (str): ID of user to be deactivated
+
+        Returns:
+            Deferred
+        """
+        # FIXME: Theoretically there is a race here wherein user resets
+        # password using threepid.
+
+        # first delete any devices belonging to the user, which will also
+        # delete corresponding access tokens.
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+        # then delete any remaining access tokens which weren't associated with
+        # a device.
+        yield self._auth_handler.delete_access_tokens_for_user(user_id)
+
+        yield self.store.user_delete_threepids(user_id)
+        yield self.store.user_set_password_hash(user_id, None)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 579d8477ba..0e83453851 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 from synapse.api import errors
 from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -171,12 +172,30 @@ class DeviceHandler(BaseHandler):
         yield self.notify_device_update(user_id, [device_id])
 
     @defer.inlineCallbacks
+    def delete_all_devices_for_user(self, user_id, except_device_id=None):
+        """Delete all of the user's devices
+
+        Args:
+            user_id (str):
+            except_device_id (str|None): optional device id which should not
+                be deleted
+
+        Returns:
+            defer.Deferred:
+        """
+        device_map = yield self.store.get_devices_by_user(user_id)
+        device_ids = device_map.keys()
+        if except_device_id is not None:
+            device_ids = [d for d in device_ids if d != except_device_id]
+        yield self.delete_devices(user_id, device_ids)
+
+    @defer.inlineCallbacks
     def delete_devices(self, user_id, device_ids):
         """ Delete several devices
 
         Args:
             user_id (str):
-            device_ids (str): The list of device IDs to delete
+            device_ids (List[str]): The list of device IDs to delete
 
         Returns:
             defer.Deferred:
@@ -495,6 +514,9 @@ class DeviceListEduUpdater(object):
                     # This makes it more likely that the device lists will
                     # eventually become consistent.
                     return
+                except FederationDeniedError as e:
+                    logger.info(e)
+                    return
                 except Exception:
                     # TODO: Remember that we are now out of sync and try again
                     # later
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..d996aa90bb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.stringutils import random_string
 
 
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
         """
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.federation = hs.get_federation_sender()
 
         hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
         message_type = content["type"]
         message_id = content["message_id"]
         for user_id, by_device in content["messages"].items():
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise SynapseError(400, "Not a user here")
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
         local_messages = {}
         remote_messages = {}
         for user_id, by_device in messages.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
                     device_id: {
                         "content": message_content,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a0464ae5c0..8580ada60a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.federation = hs.get_replication_layer()
         self.federation.register_query_handler(
@@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
     def send_room_alias_update_event(self, requester, user_id, room_id):
         aliases = yield self.store.get_aliases_for_room(room_id)
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Aliases,
@@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
         if not alias_event or alias_event.content.get("alias", "") != alias_str:
             return
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.CanonicalAlias,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 668a90e495..9aa95f89e6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,8 +19,10 @@ import logging
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
+from synapse.api.errors import (
+    SynapseError, CodeMessageException, FederationDeniedError,
+)
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -32,7 +34,7 @@ class E2eKeysHandler(object):
         self.store = hs.get_datastore()
         self.federation = hs.get_replication_layer()
         self.device_handler = hs.get_device_handler()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
         # doesn't really work as part of the generic query API, because the
@@ -70,7 +72,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_ids in device_keys_query.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 local_query[user_id] = device_ids
             else:
                 remote_queries[user_id] = device_ids
@@ -139,6 +142,10 @@ class E2eKeysHandler(object):
                 failures[destination] = {
                     "status": 503, "message": "Not ready for retry",
                 }
+            except FederationDeniedError as e:
+                failures[destination] = {
+                    "status": 403, "message": "Federation Denied",
+                }
             except Exception as e:
                 # include ConnectionRefused and other errors
                 failures[destination] = {
@@ -170,7 +177,8 @@ class E2eKeysHandler(object):
 
         result_dict = {}
         for user_id, device_ids in query.items():
-            if not self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
                 raise SynapseError(400, "Not a user here")
@@ -213,7 +221,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_keys in query.get("one_time_keys", {}).items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac70730885..46bcf8b081 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,6 +23,7 @@ from ._base import BaseHandler
 
 from synapse.api.errors import (
     AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+    FederationDeniedError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
@@ -74,6 +76,7 @@ class FederationHandler(BaseHandler):
         self.is_mine_id = hs.is_mine_id
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         self.replication_layer.set_handler(self)
 
@@ -782,6 +785,9 @@ class FederationHandler(BaseHandler):
                 except NotRetryingDestination as e:
                     logger.info(e.message)
                     continue
+                except FederationDeniedError as e:
+                    logger.info(e)
+                    continue
                 except Exception as e:
                     logger.exception(
                         "Failed to backfill from %s because %s",
@@ -804,13 +810,12 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
+        resolve = logcontext.preserve_fn(
+            self.state_handler.resolve_state_groups_for_events
+        )
         states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
-                    room_id, [e]
-                )
-                for e in event_ids
-            ], consumeErrors=True,
+            [resolve(room_id, [e]) for e in event_ids],
+            consumeErrors=True,
         ))
         states = dict(zip(event_ids, [s.state for s in states]))
 
@@ -1004,8 +1009,7 @@ class FederationHandler(BaseHandler):
         })
 
         try:
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder,
             )
         except AuthError as e:
@@ -1245,8 +1249,7 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -1828,8 +1831,8 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         if different_auth and not event.internal_metadata.is_outlier():
@@ -1910,8 +1913,8 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                self._update_context_for_auth_events(
-                    context, auth_events, event_key,
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
                 )
 
         try:
@@ -1920,11 +1923,15 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e
 
-    def _update_context_for_auth_events(self, context, auth_events,
+    @defer.inlineCallbacks
+    def _update_context_for_auth_events(self, event, context, auth_events,
                                         event_key):
-        """Update the state_ids in an event context after auth event resolution
+        """Update the state_ids in an event context after auth event resolution,
+        storing the changes as a new state group.
 
         Args:
+            event (Event): The event we're handling the context for
+
             context (synapse.events.snapshot.EventContext): event context
                 to be updated
 
@@ -1947,7 +1954,13 @@ class FederationHandler(BaseHandler):
         context.prev_state_ids.update({
             k: a.event_id for k, a in auth_events.iteritems()
         })
-        context.state_group = self.store.get_next_state_group()
+        context.state_group = yield self.store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=context.prev_group,
+            delta_ids=context.delta_ids,
+            current_state_ids=context.current_state_ids,
+        )
 
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -2117,8 +2130,7 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
 
@@ -2156,8 +2168,7 @@ class FederationHandler(BaseHandler):
         """
         builder = self.event_builder_factory.new(event_dict)
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -2207,8 +2218,9 @@ class FederationHandler(BaseHandler):
 
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(builder=builder)
+        event, context = yield self.event_creation_handler.create_new_client_event(
+            builder=builder,
+        )
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7e5d3f148d..e4d0cc8b02 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
 
             defer.returnValue({"groups": result})
         else:
-            result = yield self.transport_client.get_publicised_groups_for_user(
-                get_domain_from_id(user_id), user_id
+            bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+                get_domain_from_id(user_id), [user_id],
             )
+            result = bulk_result.get("users", {}).get(user_id)
             # TODO: Verify attestations
-            defer.returnValue(result)
+            defer.returnValue({"groups": result})
 
     @defer.inlineCallbacks
     def bulk_get_publicised_groups(self, user_ids, proxy=True):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 21f1717dd2..4e9752ccbd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
+# Copyright 2017 - 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,23 +47,11 @@ class MessageHandler(BaseHandler):
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
-        self.validator = EventValidator()
-        self.profile_handler = hs.get_profile_handler()
 
         self.pagination_lock = ReadWriteLock()
 
-        self.pusher_pool = hs.get_pusherpool()
-
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Limiter(max_count=5)
-
-        self.action_generator = hs.get_action_generator()
-
-        self.spam_checker = hs.get_spam_checker()
-
     @defer.inlineCallbacks
-    def purge_history(self, room_id, event_id):
+    def purge_history(self, room_id, event_id, delete_local_events=False):
         event = yield self.store.get_event(event_id)
 
         if event.room_id != room_id:
@@ -72,7 +60,7 @@ class MessageHandler(BaseHandler):
         depth = event.depth
 
         with (yield self.pagination_lock.write(room_id)):
-            yield self.store.delete_old_state(room_id, depth)
+            yield self.store.purge_history(room_id, depth, delete_local_events)
 
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
@@ -183,6 +171,162 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
 
     @defer.inlineCallbacks
+    def get_room_data(self, user_id=None, room_id=None,
+                      event_type=None, state_key="", is_guest=False):
+        """ Get data from a room.
+
+        Args:
+            event : The room path event
+        Returns:
+            The path data content.
+        Raises:
+            SynapseError if something went wrong.
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            data = yield self.state_handler.get_current_state(
+                room_id, event_type, state_key
+            )
+        elif membership == Membership.LEAVE:
+            key = (event_type, state_key)
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], [key]
+            )
+            data = room_state[membership_event_id].get(key)
+
+        defer.returnValue(data)
+
+    @defer.inlineCallbacks
+    def _check_in_room_or_world_readable(self, room_id, user_id):
+        try:
+            # check_user_was_in_room will return the most recent membership
+            # event for the user if:
+            #  * The user is a non-guest user, and was ever in the room
+            #  * The user is a guest user, and has joined the room
+            # else it will throw.
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+            return
+        except AuthError:
+            visibility = yield self.state_handler.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
+            )
+            if (
+                visibility and
+                visibility.content["history_visibility"] == "world_readable"
+            ):
+                defer.returnValue((Membership.JOIN, None))
+                return
+            raise AuthError(
+                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            )
+
+    @defer.inlineCallbacks
+    def get_state_events(self, user_id, room_id, is_guest=False):
+        """Retrieve all state events for a given room. If the user is
+        joined to the room then return the current state. If the user has
+        left the room return the state events from when they left.
+
+        Args:
+            user_id(str): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A list of dicts representing state events. [{}, {}, {}]
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            room_state = yield self.state_handler.get_current_state(room_id)
+        elif membership == Membership.LEAVE:
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], None
+            )
+            room_state = room_state[membership_event_id]
+
+        now = self.clock.time_msec()
+        defer.returnValue(
+            [serialize_event(c, now) for c in room_state.values()]
+        )
+
+    @defer.inlineCallbacks
+    def get_joined_members(self, requester, room_id):
+        """Get all the joined members in the room and their profile information.
+
+        If the user has left the room, NotImplementedError is raised instead.
+
+        Args:
+            requester(Requester): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A dict of user_id to profile info
+        """
+        user_id = requester.user.to_string()
+        if not requester.app_service:
+            # We check AS auth after fetching the room membership, as it
+            # requires us to pull out all joined members anyway.
+            membership, _ = yield self._check_in_room_or_world_readable(
+                room_id, user_id
+            )
+            if membership != Membership.JOIN:
+                raise NotImplementedError(
+                    "Getting joined members after leaving is not implemented"
+                )
+
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+        # If this is an AS, double check that they are allowed to see the members.
+        # This can either be because the AS user is in the room or because there
+        # is a user in the room that the AS is "interested in"
+        if requester.app_service and user_id not in users_with_profile:
+            for uid in users_with_profile:
+                if requester.app_service.is_interested_in_user(uid):
+                    break
+            else:
+                # Loop fell through, AS has no interested users in room
+                raise AuthError(403, "Appservice not in room")
+
+        defer.returnValue({
+            user_id: {
+                "avatar_url": profile.avatar_url,
+                "display_name": profile.display_name,
+            }
+            for user_id, profile in users_with_profile.iteritems()
+        })
+
+
+class EventCreationHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.clock = hs.get_clock()
+        self.validator = EventValidator()
+        self.profile_handler = hs.get_profile_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+        self.server_name = hs.hostname
+        self.ratelimiter = hs.get_ratelimiter()
+        self.notifier = hs.get_notifier()
+
+        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        self.base_handler = BaseHandler(hs)
+
+        self.pusher_pool = hs.get_pusherpool()
+
+        # We arbitrarily limit concurrent event creation for a room to 5.
+        # This is to stop us from diverging history *too* much.
+        self.limiter = Limiter(max_count=5)
+
+        self.action_generator = hs.get_action_generator()
+
+        self.spam_checker = hs.get_spam_checker()
+
+    @defer.inlineCallbacks
     def create_event(self, requester, event_dict, token_id=None, txn_id=None,
                      prev_event_ids=None):
         """
@@ -234,7 +378,7 @@ class MessageHandler(BaseHandler):
             if txn_id is not None:
                 builder.internal_metadata.txn_id = txn_id
 
-            event, context = yield self._create_new_client_event(
+            event, context = yield self.create_new_client_event(
                 builder=builder,
                 requester=requester,
                 prev_event_ids=prev_event_ids,
@@ -259,11 +403,6 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )
 
-        # We check here if we are currently being rate limited, so that we
-        # don't do unnecessary work. We check again just before we actually
-        # send the event.
-        yield self.ratelimit(requester, update=False)
-
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -342,137 +481,9 @@ class MessageHandler(BaseHandler):
         )
         defer.returnValue(event)
 
+    @measure_func("create_new_client_event")
     @defer.inlineCallbacks
-    def get_room_data(self, user_id=None, room_id=None,
-                      event_type=None, state_key="", is_guest=False):
-        """ Get data from a room.
-
-        Args:
-            event : The room path event
-        Returns:
-            The path data content.
-        Raises:
-            SynapseError if something went wrong.
-        """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
-
-        if membership == Membership.JOIN:
-            data = yield self.state_handler.get_current_state(
-                room_id, event_type, state_key
-            )
-        elif membership == Membership.LEAVE:
-            key = (event_type, state_key)
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], [key]
-            )
-            data = room_state[membership_event_id].get(key)
-
-        defer.returnValue(data)
-
-    @defer.inlineCallbacks
-    def _check_in_room_or_world_readable(self, room_id, user_id):
-        try:
-            # check_user_was_in_room will return the most recent membership
-            # event for the user if:
-            #  * The user is a non-guest user, and was ever in the room
-            #  * The user is a guest user, and has joined the room
-            # else it will throw.
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
-            defer.returnValue((member_event.membership, member_event.event_id))
-            return
-        except AuthError:
-            visibility = yield self.state_handler.get_current_state(
-                room_id, EventTypes.RoomHistoryVisibility, ""
-            )
-            if (
-                visibility and
-                visibility.content["history_visibility"] == "world_readable"
-            ):
-                defer.returnValue((Membership.JOIN, None))
-                return
-            raise AuthError(
-                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
-            )
-
-    @defer.inlineCallbacks
-    def get_state_events(self, user_id, room_id, is_guest=False):
-        """Retrieve all state events for a given room. If the user is
-        joined to the room then return the current state. If the user has
-        left the room return the state events from when they left.
-
-        Args:
-            user_id(str): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
-        Returns:
-            A list of dicts representing state events. [{}, {}, {}]
-        """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
-
-        if membership == Membership.JOIN:
-            room_state = yield self.state_handler.get_current_state(room_id)
-        elif membership == Membership.LEAVE:
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], None
-            )
-            room_state = room_state[membership_event_id]
-
-        now = self.clock.time_msec()
-        defer.returnValue(
-            [serialize_event(c, now) for c in room_state.values()]
-        )
-
-    @defer.inlineCallbacks
-    def get_joined_members(self, requester, room_id):
-        """Get all the joined members in the room and their profile information.
-
-        If the user has left the room return the state events from when they left.
-
-        Args:
-            requester(Requester): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
-        Returns:
-            A dict of user_id to profile info
-        """
-        user_id = requester.user.to_string()
-        if not requester.app_service:
-            # We check AS auth after fetching the room membership, as it
-            # requires us to pull out all joined members anyway.
-            membership, _ = yield self._check_in_room_or_world_readable(
-                room_id, user_id
-            )
-            if membership != Membership.JOIN:
-                raise NotImplementedError(
-                    "Getting joined members after leaving is not implemented"
-                )
-
-        users_with_profile = yield self.state.get_current_user_in_room(room_id)
-
-        # If this is an AS, double check that they are allowed to see the members.
-        # This can either be because the AS user is in the room or becuase there
-        # is a user in the room that the AS is "interested in"
-        if requester.app_service and user_id not in users_with_profile:
-            for uid in users_with_profile:
-                if requester.app_service.is_interested_in_user(uid):
-                    break
-            else:
-                # Loop fell through, AS has no interested users in room
-                raise AuthError(403, "Appservice not in room")
-
-        defer.returnValue({
-            user_id: {
-                "avatar_url": profile.avatar_url,
-                "display_name": profile.display_name,
-            }
-            for user_id, profile in users_with_profile.iteritems()
-        })
-
-    @measure_func("_create_new_client_event")
-    @defer.inlineCallbacks
-    def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
+    def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
         if prev_event_ids:
             prev_events = yield self.store.add_event_hashes(prev_event_ids)
             prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
         builder.prev_events = prev_events
         builder.depth = depth
 
-        state_handler = self.state_handler
-
-        context = yield state_handler.compute_event_context(builder)
+        context = yield self.state.compute_event_context(builder)
         if requester:
             context.app_service = requester.app_service
 
@@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
         # We now need to go and hit out to wherever we need to hit out to.
 
         if ratelimit:
-            yield self.ratelimit(requester)
+            yield self.base_handler.ratelimit(requester)
 
         try:
             yield self.auth.check_from_context(event, context)
@@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
             logger.exception("Failed to encode content: %r", event.content)
             raise
 
-        yield self.maybe_kick_guest_users(event, context)
+        yield self.base_handler.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 5e5b1952dd..9800e24453 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
             "profile", self.on_profile_query
         )
 
+        self.user_directory_handler = hs.get_user_directory_handler()
+
         self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
 
     @defer.inlineCallbacks
@@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_displayname
         )
 
+        if self.hs.config.user_directory_search_all_users:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            yield self.user_directory_handler.handle_local_profile_change(
+                target_user.to_string(), profile
+            )
+
         yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
@@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_avatar_url
         )
 
+        if self.hs.config.user_directory_search_all_users:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            yield self.user_directory_handler.handle_local_profile_change(
+                target_user.to_string(), profile
+            )
+
         yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index f6e7e58563..9021d4d57f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
 from synapse import types
 from synapse.types import UserID
 from synapse.util.async import run_on_reactor
+from synapse.util.threepids import check_3pid_allowed
 from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
@@ -38,6 +39,7 @@ class RegistrationHandler(BaseHandler):
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
+        self.user_directory_handler = hs.get_user_directory_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
 
         self._next_generated_user_id = None
@@ -130,7 +132,7 @@ class RegistrationHandler(BaseHandler):
         yield run_on_reactor()
         password_hash = None
         if password:
-            password_hash = self.auth_handler().hash(password)
+            password_hash = yield self.auth_handler().hash(password)
 
         if localpart:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -165,6 +167,13 @@ class RegistrationHandler(BaseHandler):
                 ),
                 admin=admin,
             )
+
+            if self.hs.config.user_directory_search_all_users:
+                profile = yield self.store.get_profileinfo(localpart)
+                yield self.user_directory_handler.handle_local_profile_change(
+                    user_id, profile
+                )
+
         else:
             # autogen a sequential user ID
             attempts = 0
@@ -285,7 +294,7 @@ class RegistrationHandler(BaseHandler):
         """
 
         for c in threepidCreds:
-            logger.info("validating theeepidcred sid %s on id server %s",
+            logger.info("validating threepidcred sid %s on id server %s",
                         c['sid'], c['idServer'])
             try:
                 identity_handler = self.hs.get_handlers().identity_handler
@@ -299,6 +308,11 @@ class RegistrationHandler(BaseHandler):
             logger.info("got threepid with medium '%s' and address '%s'",
                         threepid['medium'], threepid['address'])
 
+            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+                raise RegistrationError(
+                    403, "Third party identifier is not allowed"
+                )
+
     @defer.inlineCallbacks
     def bind_emails(self, user_id, threepidCreds):
         """Links emails with a user ID and informs an identity server.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 496f1fc39b..6ab020bf41 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
         super(RoomCreationHandler, self).__init__(hs)
 
         self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     @defer.inlineCallbacks
     def create_room(self, requester, config, ratelimit=True):
@@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
 
         creation_content = config.get("creation_content", {})
 
-        msg_handler = self.hs.get_handlers().message_handler
         room_member_handler = self.hs.get_handlers().room_member_handler
 
         yield self._send_events_for_new_room(
             requester,
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config=preset_config,
             invite_list=invite_list,
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
 
         if "name" in config:
             name = config["name"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Name,
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
 
         if "topic" in config:
             topic = config["topic"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Topic,
@@ -205,12 +205,12 @@ class RoomCreationHandler(BaseHandler):
                 },
                 ratelimit=False)
 
-        content = {}
-        is_direct = config.get("is_direct", None)
-        if is_direct:
-            content["is_direct"] = is_direct
-
         for invitee in invite_list:
+            content = {}
+            is_direct = config.get("is_direct", None)
+            if is_direct:
+                content["is_direct"] = is_direct
+
             yield room_member_handler.update_membership(
                 requester,
                 UserID.from_string(invitee),
@@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
             self,
             creator,  # A Requester object.
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config,
             invite_list,
@@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
         @defer.inlineCallbacks
         def send(etype, content, **kwargs):
             event = create(etype, content, **kwargs)
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 creator,
                 event,
                 ratelimit=False
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index bb40075387..dfa09141ed 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
         if limit:
             step = limit + 1
         else:
-            step = len(rooms_to_scan)
+            # step cannot be zero
+            step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
 
         chunk = []
         for i in xrange(0, len(rooms_to_scan), step):
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 970fec0666..37dc5e99ab 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
         super(RoomMemberHandler, self).__init__(hs)
 
         self.profile_handler = hs.get_profile_handler()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
@@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
     ):
         if content is None:
             content = {}
-        msg_handler = self.hs.get_handlers().message_handler
 
         content["membership"] = membership
         if requester.is_guest:
             content["kind"] = "guest"
 
-        event, context = yield msg_handler.create_event(
+        event, context = yield self.event_creation_hander.create_event(
             requester,
             {
                 "type": EventTypes.Member,
@@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
         )
 
         # Check if this event matches the previous membership event for the user.
-        duplicate = yield msg_handler.deduplicate_state_event(event, context)
+        duplicate = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if duplicate is not None:
             # Discard the new event since this membership change is a no-op.
             defer.returnValue(duplicate)
 
-        yield msg_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -189,6 +192,10 @@ class RoomMemberHandler(BaseHandler):
         content_specified = bool(content)
         if content is None:
             content = {}
+        else:
+            # We do a copy here as we potentially change some keys
+            # later on.
+            content = dict(content)
 
         effective_membership_state = action
         if action in ["kick", "unban"]:
@@ -390,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
         else:
             requester = synapse.types.create_requester(target_user)
 
-        message_handler = self.hs.get_handlers().message_handler
-        prev_event = yield message_handler.deduplicate_state_event(event, context)
+        prev_event = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if prev_event is not None:
             return
 
@@ -408,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
             if is_blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
-        yield message_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -640,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
             )
         )
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.ThirdPartyInvite,
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
new file mode 100644
index 0000000000..e057ae54c9
--- /dev/null
+++ b/synapse/handlers/set_password.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError, SynapseError
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SetPasswordHandler(BaseHandler):
+    """Handler which deals with changing user account passwords"""
+    def __init__(self, hs):
+        super(SetPasswordHandler, self).__init__(hs)
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
+
+    @defer.inlineCallbacks
+    def set_password(self, user_id, newpassword, requester=None):
+        password_hash = yield self._auth_handler.hash(newpassword)
+
+        except_device_id = requester.device_id if requester else None
+        except_access_token_id = requester.access_token_id if requester else None
+
+        try:
+            yield self.store.user_set_password_hash(user_id, password_hash)
+        except StoreError as e:
+            if e.code == 404:
+                raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+            raise e
+
+        # we want to log out all of the user's other sessions. First delete
+        # all their other devices.
+        yield self._device_handler.delete_all_devices_for_user(
+            user_id, except_device_id=except_device_id,
+        )
+
+        # and now delete any access tokens which weren't associated with
+        # devices (or were associated with this device).
+        yield self._auth_handler.delete_access_tokens_for_user(
+            user_id, except_token_id=except_access_token_id,
+        )
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index b5be5d9623..714f0195c8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.storage.roommember import ProfileInfo
 from synapse.util.metrics import Measure
 from synapse.util.async import sleep
+from synapse.types import get_localpart_from_id
 
 
 logger = logging.getLogger(__name__)
 
 
-class UserDirectoyHandler(object):
+class UserDirectoryHandler(object):
     """Handles querying of and keeping updated the user_directory.
 
     N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
     one public room.
     """
 
-    INITIAL_SLEEP_MS = 50
-    INITIAL_SLEEP_COUNT = 100
-    INITIAL_BATCH_SIZE = 100
+    INITIAL_ROOM_SLEEP_MS = 50
+    INITIAL_ROOM_SLEEP_COUNT = 100
+    INITIAL_ROOM_BATCH_SIZE = 100
+    INITIAL_USER_SLEEP_MS = 10
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
@@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
         self.update_user_directory = hs.config.update_user_directory
+        self.search_all_users = hs.config.user_directory_search_all_users
 
         # When start up for the first time we need to populate the user_directory.
         # This is a set of user_id's we've inserted already
@@ -111,6 +114,15 @@ class UserDirectoyHandler(object):
             self._is_processing = False
 
     @defer.inlineCallbacks
+    def handle_local_profile_change(self, user_id, profile):
+        """Called to update index of our local user profiles when they change
+        irrespective of any rooms the user may be in.
+        """
+        yield self.store.update_profile_in_user_dir(
+            user_id, profile.display_name, profile.avatar_url, None,
+        )
+
+    @defer.inlineCallbacks
     def _unsafe_process(self):
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
@@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
         room_ids = yield self.store.get_all_rooms()
 
         logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
-        num_processed_rooms = 1
+        num_processed_rooms = 0
 
         for room_id in room_ids:
-            logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
+            logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
             yield self._handle_initial_room(room_id)
             num_processed_rooms += 1
-            yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+            yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
 
         logger.info("Processed all rooms.")
 
+        if self.search_all_users:
+            num_processed_users = 0
+            user_ids = yield self.store.get_all_local_users()
+            logger.info("Doing initial update of user directory. %d users", len(user_ids))
+            for user_id in user_ids:
+                # We add profiles for all users even if they don't match the
+                # include pattern, just in case we want to change it in future
+                logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
+                yield self._handle_local_user(user_id)
+                num_processed_users += 1
+                yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+
+            logger.info("Processed all users")
+
         self.initially_handled_users = None
         self.initially_handled_users_in_public = None
         self.initially_handled_users_share = None
@@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
         to_update = set()
         count = 0
         for user_id in user_ids:
-            if count % self.INITIAL_SLEEP_COUNT == 0:
-                yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+            if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
 
             if not self.is_mine_id(user_id):
                 count += 1
@@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
                 if user_id == other_user_id:
                     continue
 
-                if count % self.INITIAL_SLEEP_COUNT == 0:
-                    yield sleep(self.INITIAL_SLEEP_MS / 1000.)
+                if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                    yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
                 count += 1
 
                 user_set = (user_id, other_user_id)
@@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
                 else:
                     self.initially_handled_users_share_private_room.add(user_set)
 
-                if len(to_insert) > self.INITIAL_BATCH_SIZE:
+                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
                     yield self.store.add_users_who_share_room(
                         room_id, not is_public, to_insert,
                     )
                     to_insert.clear()
 
-                if len(to_update) > self.INITIAL_BATCH_SIZE:
+                if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
                     yield self.store.update_users_who_share_room(
                         room_id, not is_public, to_update,
                     )
@@ -385,14 +411,28 @@ class UserDirectoyHandler(object):
                 yield self._handle_remove_user(room_id, user_id)
 
     @defer.inlineCallbacks
+    def _handle_local_user(self, user_id):
+        """Adds a new local roomless user into the user_directory_search table.
+        Used to populate the user index when the
+        user_directory_search_all_users option is enabled.
+        """
+        logger.debug("Adding new local user to dir, %r", user_id)
+
+        profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
+
+        row = yield self.store.get_user_in_directory(user_id)
+        if not row:
+            yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+
+    @defer.inlineCallbacks
     def _handle_new_user(self, room_id, user_id, profile):
         """Called when we might need to add user to directory
 
         Args:
-            room_id (str): room_id that user joined or started being public that
+            room_id (str): room_id that user joined or started being public
             user_id (str)
         """
-        logger.debug("Adding user to dir, %r", user_id)
+        logger.debug("Adding new user to dir, %r", user_id)
 
         row = yield self.store.get_user_in_directory(user_id)
         if not row:
@@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
             if not row:
                 yield self.store.add_users_to_public_room(room_id, [user_id])
         else:
-            logger.debug("Not adding user to public dir, %r", user_id)
+            logger.debug("Not adding new user to public dir, %r", user_id)
 
         # Now we update users who share rooms with users. We do this by getting
         # all the current users in the room and seeing which aren't already
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4abb479ae3..f3e4973c2e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util import logcontext
 import synapse.metrics
@@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import (
     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
     readBody, PartialDownloadError,
+    HTTPConnectionPool,
 )
 from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
 from twisted.web.http import PotentialDataLoss
@@ -64,13 +66,23 @@ class SimpleHttpClient(object):
     """
     def __init__(self, hs):
         self.hs = hs
+
+        pool = HTTPConnectionPool(reactor)
+
+        # the pusher makes lots of concurrent SSL connections to sygnal, and
+        # tends to do so in batches, so we need to allow the pool to keep lots
+        # of idle connections around.
+        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        pool.cachedConnectionTimeout = 2 * 60
+
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
         self.agent = Agent(
             reactor,
             connectTimeout=15,
-            contextFactory=hs.get_http_client_context_factory()
+            contextFactory=hs.get_http_client_context_factory(),
+            pool=pool,
         )
         self.user_agent = hs.version_string
         self.clock = hs.get_clock()
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index a97532162f..87639b9151 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -357,13 +357,14 @@ def _get_hosts_for_srv_record(dns_client, host):
     def eb(res, record_type):
         if res.check(DNSNameError):
             return []
-        logger.warn("Error looking up %s for %s: %s",
-                    record_type, host, res, res.value)
+        logger.warn("Error looking up %s for %s: %s", record_type, host, res)
         return res
 
     # no logcontexts here, so we can safely fire these off and gatherResults
-    d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
-    d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
+    d1 = dns_client.lookupAddress(host).addCallbacks(
+        cb, eb, errbackArgs=("A", ))
+    d2 = dns_client.lookupIPV6Address(host).addCallbacks(
+        cb, eb, errbackArgs=("AAAA", ))
     results = yield defer.DeferredList(
         [d1, d2], consumeErrors=True)
 
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 833496b72d..9145405cb0 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,7 +27,7 @@ import synapse.metrics
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import (
-    SynapseError, Codes, HttpResponseException,
+    SynapseError, Codes, HttpResponseException, FederationDeniedError,
 )
 
 from signedjson.sign import sign_json
@@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``HTTPRequestException``: if we get an HTTP response
                 code >= 300.
+
             Fails with ``NotRetryingDestination`` if we are not yet ready
                 to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+                is not on our federation whitelist
+
             (May also fail with plenty of other Exceptions for things like DNS
                 failures, connection failures, SSL failures.)
         """
+        if (
+            self.hs.config.federation_domain_whitelist and
+            destination not in self.hs.config.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(destination)
+
         limiter = yield synapse.util.retryutils.get_retry_limiter(
             destination,
             self.clock,
@@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         if not json_data_callback:
@@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         def body_callback(method, url_bytes, headers_dict):
@@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
         logger.debug("get_json args: %s", args)
 
@@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         response = yield self._request(
@@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         encoded_args = {}
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 3ca1c9947c..165c684d0d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -28,6 +28,7 @@ from canonicaljson import (
 )
 
 from twisted.internet import defer
+from twisted.python import failure
 from twisted.web import server, resource
 from twisted.web.server import NOT_DONE_YET
 from twisted.web.util import redirectTo
@@ -41,36 +42,70 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-incoming_requests_counter = metrics.register_counter(
-    "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+    "response_count",
     labels=["method", "servlet", "tag"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_requests",
+            "_response_time:count",
+            "_response_ru_utime:count",
+            "_response_ru_stime:count",
+            "_response_db_txn_count:count",
+            "_response_db_txn_duration:count",
+        )
+    )
 )
+
 outgoing_responses_counter = metrics.register_counter(
     "responses",
     labels=["method", "code"],
 )
 
-response_timer = metrics.register_distribution(
-    "response_time",
-    labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+    "response_time_seconds",
+    labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_time:total",
+    ),
 )
 
-response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_utime:total",
+    ),
 )
 
-response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_stime:total",
+    ),
 )
 
-response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+    "response_db_txn_count", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_count:total",
+    ),
 )
 
-response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_duration:total",
+    ),
 )
 
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
 
 _next_request_id = 0
 
@@ -106,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
         with LoggingContext(request_id) as request_context:
             with Measure(self.clock, "wrapped_request_handler"):
                 request_metrics = RequestMetrics()
+                # we start the request metrics timer here with an initial stab
+                # at the servlet name. For most requests that name will be
+                # JsonResource (or a subclass), and JsonResource._async_render
+                # will update it once it picks a servlet.
                 request_metrics.start(self.clock, name=self.__class__.__name__)
 
                 request_context.request = request_id
@@ -131,12 +170,17 @@ def wrap_request_handler(request_handler, include_metrics=False):
                             version_string=self.version_string,
                         )
                     except Exception:
-                        logger.exception(
-                            "Failed handle request %s.%s on %r: %r",
+                        # failure.Failure() fishes the original Failure out
+                        # of our stack, and thus gives us a sensible stack
+                        # trace.
+                        f = failure.Failure()
+                        logger.error(
+                            "Failed handle request %s.%s on %r: %r: %s",
                             request_handler.__module__,
                             request_handler.__name__,
                             self,
-                            request
+                            request,
+                            f.getTraceback().rstrip(),
                         )
                         respond_with_json(
                             request,
@@ -243,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource):
             if not m:
                 continue
 
-            # We found a match! Trigger callback and then return the
-            # returned response. We pass both the request and any
-            # matched groups from the regex to the callback.
+            # We found a match! First update the metrics object to indicate
+            # which servlet is handling the request.
 
             callback = path_entry.callback
 
+            servlet_instance = getattr(callback, "__self__", None)
+            if servlet_instance is not None:
+                servlet_classname = servlet_instance.__class__.__name__
+            else:
+                servlet_classname = "%r" % callback
+
+            request_metrics.name = servlet_classname
+
+            # Now trigger the callback. If it returns a response, we send it
+            # here. If it throws an exception, that is handled by the wrapper
+            # installed by @request_handler.
+
             kwargs = intern_dict({
                 name: urllib.unquote(value).decode("UTF-8") if value else value
                 for name, value in m.groupdict().items()
@@ -259,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)
 
-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
-
-            request_metrics.name = servlet_classname
-
             return
 
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+        request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
         raise UnrecognizedRequestError()
 
     def _send_response(self, request, code, response_json_object,
                        response_code_message=None):
-        # could alternatively use request.notifyFinish() and flip a flag when
-        # the Deferred fires, but since the flag is RIGHT THERE it seems like
-        # a waste.
-        if request._disconnected:
-            logger.warn(
-                "Not sending response to request %s, already disconnected.",
-                request)
-            return
-
         outgoing_responses_counter.inc(request.method, str(code))
 
         # TODO: Only enable CORS for the requests that need it.
@@ -316,7 +355,7 @@ class RequestMetrics(object):
                 )
                 return
 
-        incoming_requests_counter.inc(request.method, self.name, tag)
+        response_count.inc(request.method, self.name, tag)
 
         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
@@ -335,7 +374,10 @@ class RequestMetrics(object):
             context.db_txn_count, request.method, self.name, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration_ms / 1000., request.method, self.name, tag
+        )
+        response_db_sched_duration.inc_by(
+            context.db_sched_duration_ms / 1000., request.method, self.name, tag
         )
 
 
@@ -358,6 +400,15 @@ class RootRedirect(resource.Resource):
 def respond_with_json(request, code, json_object, send_cors=False,
                       response_code_message=None, pretty_print=False,
                       version_string="", canonical_json=True):
+    # could alternatively use request.notifyFinish() and flip a flag when
+    # the Deferred fires, but since the flag is RIGHT THERE it seems like
+    # a waste.
+    if request._disconnected:
+        logger.warn(
+            "Not sending response to request %s, already disconnected.",
+            request)
+        return
+
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 71420e54db..ef8e62901b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
             return default
 
 
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into None
 
     Returns:
         The JSON value.
@@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
     except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
+    if not content_bytes and allow_empty_body:
+        return None
+
     try:
         content = simplejson.loads(content_bytes)
     except Exception as e:
@@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
     return content
 
 
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
-    content = parse_json_value_from_request(request)
+    content = parse_json_value_from_request(
+        request, allow_empty_body=allow_empty_body,
+    )
+
+    if allow_empty_body and content is None:
+        return {}
 
     if type(content) != dict:
         message = "Content must be a JSON object."
diff --git a/synapse/http/site.py b/synapse/http/site.py
index cd1492b1c3..e422c8dfae 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
             context = LoggingContext.current_context()
             ru_utime, ru_stime = context.get_resource_usage()
             db_txn_count = context.db_txn_count
-            db_txn_duration = context.db_txn_duration
+            db_txn_duration_ms = context.db_txn_duration_ms
+            db_sched_duration_ms = context.db_sched_duration_ms
         except Exception:
             ru_utime, ru_stime = (0, 0)
-            db_txn_count, db_txn_duration = (0, 0)
+            db_txn_count, db_txn_duration_ms = (0, 0)
 
         self.site.access_logger.info(
             "%s - %s - {%s}"
-            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
             " %sB %s \"%s %s %s\" \"%s\"",
             self.getClientIP(),
             self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
             int(time.time() * 1000) - self.start_time,
             int(ru_utime * 1000),
             int(ru_stime * 1000),
-            int(db_txn_duration * 1000),
+            db_sched_duration_ms,
+            db_txn_duration_ms,
             int(db_txn_count),
             self.sentLength,
             self.code,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..e0cfb7d08f 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
             num_pending += 1
 
         num_pending += len(reactor.threadCallQueue)
-
         start = time.time() * 1000
         ret = func(*args, **kwargs)
         end = time.time() * 1000
+
+        # record the amount of wallclock time spent running pending calls.
+        # This is a proxy for the actual amount of time between reactor polls,
+        # since about 25% of time is actually spent running things triggered by
+        # I/O events, but that is harder to capture without rewriting half the
+        # reactor.
         tick_time.inc_by(end - start)
         pending_calls_metric.inc_by(num_pending)
 
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..ff5aa8c0e1 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -15,18 +15,38 @@
 
 
 from itertools import chain
+import logging
 
+logger = logging.getLogger(__name__)
 
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
-    # flatten a list-of-lists
-    return list(chain.from_iterable(map(func, items)))
+
+def flatten(items):
+    """Flatten a list of lists
+
+    Args:
+        items: iterable[iterable[X]]
+
+    Returns:
+        list[X]: flattened list
+    """
+    return list(chain.from_iterable(items))
 
 
 class BaseMetric(object):
+    """Base class for metrics which report a single value per label set
+    """
 
-    def __init__(self, name, labels=[]):
-        self.name = name
+    def __init__(self, name, labels=[], alternative_names=[]):
+        """
+        Args:
+            name (str): principal name for this metric
+            labels (list(str)): names of the labels which will be reported
+                for this metric
+            alternative_names (iterable(str)): list of alternative names for
+                this metric. This can be useful to provide a migration path
+                when renaming metrics.
+        """
+        self._names = [name] + list(alternative_names)
         self.labels = labels  # OK not to clone as we never write it
 
     def dimension(self):
@@ -36,7 +56,7 @@ class BaseMetric(object):
         return not len(self.labels)
 
     def _render_labelvalue(self, value):
-        # TODO: some kind of value escape
+        # TODO: escape backslashes, quotes and newlines
         return '"%s"' % (value)
 
     def _render_key(self, values):
@@ -47,19 +67,60 @@ class BaseMetric(object):
                       for k, v in zip(self.labels, values)])
         )
 
+    def _render_for_labels(self, label_values, value):
+        """Render this metric for a single set of labels
+
+        Args:
+            label_values (list[str]): values for each of the labels
+            value: value of the metric with these labels
+
+        Returns:
+            iterable[str]: rendered metric
+        """
+        rendered_labels = self._render_key(label_values)
+        return (
+            "%s%s %.12g" % (name, rendered_labels, value)
+            for name in self._names
+        )
+
+    def render(self):
+        """Render this metric
+
+        Each metric is rendered as:
+
+            name{label1="val1",label2="val2"} value
+
+        https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+        Returns:
+            iterable[str]: rendered metrics
+        """
+        raise NotImplementedError()
+
 
 class CounterMetric(BaseMetric):
     """The simplest kind of metric; one that stores a monotonically-increasing
-    integer that counts events."""
+    value that counts events or running totals.
+
+    Example use cases for Counters:
+    - Number of requests processed
+    - Number of items that were inserted into a queue
+    - Total amount of data that a system has processed
+    Counters can only go up (and be reset when the process restarts).
+    """
 
     def __init__(self, *args, **kwargs):
         super(CounterMetric, self).__init__(*args, **kwargs)
 
+        # dict[tuple[str]]: value for each set of label values. the keys are
+        # the label values, in the same order as the labels in self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty tuple).
         self.counts = {}
 
         # Scalar metrics are never empty
         if self.is_scalar():
-            self.counts[()] = 0
+            self.counts[()] = 0.
 
     def inc_by(self, incr, *values):
         if len(values) != self.dimension():
@@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
     def inc(self, *values):
         self.inc_by(1, *values)
 
-    def render_item(self, k):
-        return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
-
     def render(self):
-        return map_concat(self.render_item, sorted(self.counts.keys()))
+        return flatten(
+            self._render_for_labels(k, self.counts[k])
+            for k in sorted(self.counts.keys())
+        )
 
 
 class CallbackMetric(BaseMetric):
@@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
         self.callback = callback
 
     def render(self):
-        value = self.callback()
+        try:
+            value = self.callback()
+        except Exception:
+            logger.exception("Failed to render %s", self.name)
+            return ["# FAILED to render " + self.name]
 
         if self.is_scalar():
-            return ["%s %.12g" % (self.name, value)]
+            return list(self._render_for_labels([], value))
 
-        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
-                for k in sorted(value.keys())]
+        return flatten(
+            self._render_for_labels(k, value[k])
+            for k in sorted(value.keys())
+        )
 
 
 class DistributionMetric(object):
@@ -126,7 +193,9 @@ class DistributionMetric(object):
 
 
 class CacheMetric(object):
-    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+    __slots__ = (
+        "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+    )
 
     def __init__(self, name, size_callback, cache_name):
         self.name = name
@@ -134,6 +203,7 @@ class CacheMetric(object):
 
         self.hits = 0
         self.misses = 0
+        self.evicted_size = 0
 
         self.size_callback = size_callback
 
@@ -143,6 +213,9 @@ class CacheMetric(object):
     def inc_misses(self):
         self.misses += 1
 
+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
     def render(self):
         size = self.size_callback()
         hits = self.hits
@@ -152,6 +225,9 @@ class CacheMetric(object):
             """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+            """%s:evicted_size{name="%s"} %d""" % (
+                self.name, self.cache_name, self.evicted_size
+            ),
         ]
 
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index dc680ddf43..097c844d31 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.internet import defer
 
 from synapse.types import UserID
 
@@ -81,6 +82,7 @@ class ModuleApi(object):
         reg = self.hs.get_handlers().registration_handler
         return reg.register(localpart=localpart)
 
+    @defer.inlineCallbacks
     def invalidate_access_token(self, access_token):
         """Invalidate an access token for a user
 
@@ -94,8 +96,16 @@ class ModuleApi(object):
         Raises:
             synapse.api.errors.AuthError: the access token is invalid
         """
-
-        return self._auth_handler.delete_access_token(access_token)
+        # see if the access token corresponds to a device
+        user_info = yield self._auth.get_user_by_access_token(access_token)
+        device_id = user_info.get("device_id")
+        user_id = user_info["user"].to_string()
+        if device_id:
+            # delete the device, which will also delete its access tokens
+            yield self.hs.get_device_handler().delete_device(user_id, device_id)
+        else:
+            # no associated device. Just delete the access token.
+            yield self._auth_handler.delete_access_token(access_token)
 
     def run_db_interaction(self, desc, func, *args, **kwargs):
         """Run a function with a database connection
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 626da778cd..ef042681bc 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -255,9 +255,7 @@ class Notifier(object):
         )
 
         if self.federation_sender:
-            preserve_fn(self.federation_sender.notify_new_events)(
-                room_stream_id
-            )
+            self.federation_sender.notify_new_events(room_stream_id)
 
         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
             self._user_joined_room(event.state_key, event.room_id)
@@ -297,8 +295,7 @@ class Notifier(object):
     def on_new_replication_data(self):
         """Used to inform replication listeners that something has happend
         without waking up any of the normal user event streams"""
-        with PreserveLoggingContext():
-            self.notify_replication()
+        self.notify_replication()
 
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -516,8 +513,14 @@ class Notifier(object):
             self.replication_deferred = ObservableDeferred(defer.Deferred())
             deferred.callback(None)
 
-        for cb in self.replication_callbacks:
-            preserve_fn(cb)()
+            # the callbacks may well outlast the current request, so we run
+            # them in the sentinel logcontext.
+            #
+            # (ideally it would be up to the callbacks to know if they were
+            # starting off background processes and drop the logcontext
+            # accordingly, but that requires more changes)
+            for cb in self.replication_callbacks:
+                cb()
 
     @defer.inlineCallbacks
     def wait_for_replication(self, callback, timeout):
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c16f61452c..2cbac571b8 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -13,21 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from synapse.push import PusherConfigException
+import logging
 
 from twisted.internet import defer, reactor
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
-import logging
 import push_rule_evaluator
 import push_tools
-
+import synapse
+from synapse.push import PusherConfigException
 from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+    "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+    "http_pushes_failed",
+)
+
 
 class HttpPusher(object):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
@@ -152,9 +161,16 @@ class HttpPusher(object):
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
 
+        logger.info(
+            "Processing %i unprocessed push actions for %s starting at "
+            "stream_ordering %s",
+            len(unprocessed), self.name, self.last_stream_ordering,
+        )
+
         for push_action in unprocessed:
             processed = yield self._process_one(push_action)
             if processed:
+                http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action['stream_ordering']
                 yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -169,6 +185,7 @@ class HttpPusher(object):
                         self.failing_since
                     )
             else:
+                http_push_failed_counter.inc()
                 if not self.failing_since:
                     self.failing_since = self.clock.time_msec()
                     yield self.store.update_pusher_failing_since(
@@ -316,7 +333,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
         except Exception:
-            logger.warn("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to push event %s to %s",
+                event.event_id, self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
@@ -325,7 +345,7 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _send_badge(self, badge):
-        logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+        logger.info("Sending updated badge count %d to %s", badge, self.name)
         d = {
             'notification': {
                 'id': '',
@@ -347,7 +367,10 @@ class HttpPusher(object):
         try:
             resp = yield self.http_client.post_json_get_json(self.url, d)
         except Exception:
-            logger.exception("Failed to push %s ", self.url)
+            logger.warn(
+                "Failed to send badge count to %s",
+                self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 34cb108dcb..134e89b371 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -103,19 +103,25 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
-        all = yield self.store.get_all_pushers()
-        logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
-        )
-        for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+    def remove_pushers_by_access_token(self, user_id, access_tokens):
+        """Remove the pushers for a given user corresponding to a set of
+        access_tokens.
+
+        Args:
+            user_id (str): user to remove pushers for
+            access_tokens (Iterable[int]): access token *ids* to remove pushers
+                for
+        """
+        tokens = set(access_tokens)
+        for p in (yield self.store.get_pushers_by_user_id(user_id)):
+            if p['access_token'] in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
-                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(
+                    p['app_id'], p['pushkey'], p['user_name'],
+                )
 
     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_id, max_stream_id):
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29d7296b43..8acb5df0f3 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -19,7 +19,7 @@ from synapse.storage import DataStore
 from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.event_push_actions import EventPushActionsStore
 from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.state import StateGroupReadStore
+from synapse.storage.state import StateGroupWorkerStore
 from synapse.storage.stream import StreamStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
+class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedEventStore, self).__init__(db_conn, hs)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index d59503b905..0a9a290af4 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             self.send_error("Wrong remote")
 
     def on_RDATA(self, cmd):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.inc(stream_name)
+
         try:
-            row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
+            row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
         except Exception:
             logger.exception(
                 "[%s] Failed to parse RDATA: %r %r",
-                self.id(), cmd.stream_name, cmd.row
+                self.id(), stream_name, cmd.row
             )
             raise
 
         if cmd.token is None:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(cmd.stream_name, []).append(row)
+            self.pending_batches.setdefault(stream_name, []).append(row)
         else:
             # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(cmd.stream_name, [])
+            rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
 
-            self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
+            self.handler.on_rdata(stream_name, cmd.token, rows)
 
     def on_POSITION(self, cmd):
         self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
     },
     labels=["command", "name", "conn_id"],
 )
+
+# number of RDATA commands received, labelled by stream name (see on_RDATA)
+inbound_rdata_count = metrics.register_counter(
+    "inbound_rdata_count",
+    labels=["stream_name"],
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 1d03e79b85..786c3fe864 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -216,11 +216,12 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
+    @defer.inlineCallbacks
     def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
-        self.presence_handler.update_external_syncs_row(
+        yield self.presence_handler.update_external_syncs_row(
             conn_id, user_id, is_syncing, last_sync_ms,
         )
 
@@ -244,11 +245,12 @@ class ReplicationStreamer(object):
         getattr(self.store, cache_func).invalidate(tuple(keys))
 
     @measure_func("repl.on_user_ip")
+    @defer.inlineCallbacks
     def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
         """The client saw a user request
         """
         user_ip_cache_counter.inc()
-        self.store.insert_client_ip(
+        yield self.store.insert_client_ip(
             user_id, access_token, ip, user_agent, device_id, last_seen,
         )
 
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 1197158fdc..2ad486c67d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -128,7 +129,16 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        yield self.handlers.message_handler.purge_history(room_id, event_id)
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        delete_local_events = bool(
+            body.get("delete_local_history", False)
+        )
+
+        yield self.handlers.message_handler.purge_history(
+            room_id, event_id,
+            delete_local_events=delete_local_events,
+        )
 
         defer.returnValue((200, {}))
 
@@ -137,8 +147,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
 
     def __init__(self, hs):
-        self._auth_handler = hs.get_auth_handler()
         super(DeactivateAccountRestServlet, self).__init__(hs)
+        self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, target_user_id):
@@ -149,7 +159,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        yield self._auth_handler.deactivate_account(target_user_id)
+        yield self._deactivate_account_handler.deactivate_account(target_user_id)
         defer.returnValue((200, {}))
 
 
@@ -171,6 +181,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
         self.state = hs.get_state_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id):
@@ -203,8 +214,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
         )
         new_room_id = info["room_id"]
 
-        msg_handler = self.handlers.message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             room_creator_requester,
             {
                 "type": "m.room.message",
@@ -289,6 +299,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
         defer.returnValue((200, {"num_quarantined": num_quarantined}))
 
 
+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room; restricted to server admins.
+    Responds 200 with {"local": ..., "remote": ...} from get_media_mxcs_in_room."""
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
     This needs user to have administrator access in Synapse.
@@ -309,7 +340,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
         super(ResetPasswordRestServlet, self).__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
+        self._set_password_handler = hs.get_set_password_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, target_user_id):
@@ -330,7 +361,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
 
         logger.info("new_password: %r", new_password)
 
-        yield self.auth_handler.set_password(
+        yield self._set_password_handler.set_password(
             target_user_id, new_password, requester
         )
         defer.returnValue((200, {}))
@@ -487,3 +518,4 @@ def register_servlets(hs, http_server):
     SearchUsersRestServlet(hs).register(http_server)
     ShutdownRoomRestServlet(hs).register(http_server)
     QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5669ecb724..45844aa2d2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
 
         # convert threepid identifiers to user IDs
         if identifier["type"] == "m.id.thirdparty":
-            if 'medium' not in identifier or 'address' not in identifier:
+            address = identifier.get('address')
+            medium = identifier.get('medium')
+
+            if medium is None or address is None:
                 raise SynapseError(400, "Invalid thirdparty identifier")
 
-            address = identifier['address']
-            if identifier['medium'] == 'email':
+            if medium == 'email':
                 # For emails, transform the address to lowercase.
                 # We store all email addreses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
                 address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                identifier['medium'], address
+                medium, address,
             )
             if not user_id:
+                logger.warn(
+                    "unknown 3pid identifier medium %s, address %r",
+                    medium, address,
+                )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
             identifier = {
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 6add754782..ca49955935 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutRestServlet, self).__init__(hs)
+        self._auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        access_token = get_access_token_from_request(request)
-        yield self._auth_handler.delete_access_token(access_token)
+        try:
+            requester = yield self.auth.get_user_by_req(request)
+        except AuthError:
+            # this implies the access token has already been deleted.
+            pass
+        else:
+            if requester.device_id is None:
+                # the access token wasn't associated with a device.
+                # Just delete the access token
+                access_token = get_access_token_from_request(request)
+                yield self._auth_handler.delete_access_token(access_token)
+            else:
+                yield self._device_handler.delete_device(
+                    requester.user.to_string(), requester.device_id)
+
         defer.returnValue((200, {}))
 
 
@@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
         super(LogoutAllRestServlet, self).__init__(hs)
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
+
+        # first delete all of the user's devices
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+
+        # .. and then delete any access tokens which weren't associated with
+        # devices.
         yield self._auth_handler.delete_access_tokens_for_user(user_id)
         defer.returnValue((200, {}))
 
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 32ed1d3ab2..5c5fa8f7ab 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
         self.handlers = hs.get_handlers()
 
     def on_GET(self, request):
+
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
                             LoginType.PASSWORD
                         ]
                     },
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
                     }
-                ]}
-            )
+                ])
         else:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if require_email or not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.EMAIL_IDENTITY,
                         "stages": [
                             LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
                         ]
-                    },
+                    }
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.PASSWORD
                     }
-                ]}
-            )
+                ])
+        return (200, {"flows": flows})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 75b735b47d..fbb2fc36e4 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomStateEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
     def register(self, http_server):
         # /room/$roomid/state/$eventtype
@@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
                 content=content,
             )
         else:
-            msg_handler = self.handlers.message_handler
-            event, context = yield msg_handler.create_event(
+            event, context = yield self.event_creation_hander.create_event(
                 requester,
                 event_dict,
                 token_id=requester.access_token_id,
                 txn_id=txn_id,
             )
 
-            yield msg_handler.send_nonmember_event(requester, event, context)
+            yield self.event_creation_hander.send_nonmember_event(
+                requester, event, context,
+            )
 
         ret = {}
         if event:
@@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomSendEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
     def register(self, http_server):
         # /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event_dict = {
+            "type": event_type,
+            "content": content,
+            "room_id": room_id,
+            "sender": requester.user.to_string(),
+        }
+
+        if 'ts' in request.args and requester.app_service:
+            event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+        event = yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
-            {
-                "type": event_type,
-                "content": content,
-                "room_id": room_id,
-                "sender": requester.user.to_string(),
-            },
+            event_dict,
             txn_id=txn_id,
         )
 
@@ -487,13 +495,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
         defer.returnValue((200, content))
 
 
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):  # GET a single event by its ID
+    PATTERNS = client_path_patterns(
+        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(RoomEventServlet, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, event_id):  # NOTE(review): room_id is unused
+        requester = yield self.auth.get_user_by_req(request)
+        event = yield self.event_handler.get_event(requester.user, event_id)
+
+        time_now = self.clock.time_msec()
+        if event:
+            defer.returnValue((200, serialize_event(event, time_now)))
+        else:
+            defer.returnValue((404, "Event not found."))  # body is a plain string here
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns(
         "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
     )
 
     def __init__(self, hs):
-        super(RoomEventContext, self).__init__(hs)
+        super(RoomEventContextServlet, self).__init__(hs)
         self.clock = hs.get_clock()
         self.handlers = hs.get_handlers()
 
@@ -643,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomRedactEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -653,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,
@@ -803,4 +833,5 @@ def register_servlets(hs, http_server):
     RoomTypingRestServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventContext(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    RoomEventContextServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
 
 """This module contains base REST classes for constructing client v1 servlets.
 """
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
 import re
 
-import logging
+from twisted.internet import defer
 
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 
 logger = logging.getLogger(__name__)
 
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
         filter_json['room']['timeline']["limit"] = min(
             filter_json['room']['timeline']['limit'],
             filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+    """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+    Takes an on_POST method which returns a deferred (errcode, body) response
+    and adds exception handling to turn an InteractiveAuthIncompleteError into
+    a 401 response.
+
+    Normal usage is:
+
+    @interactive_auth_handler
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        # ...
+        yield self.auth_handler.check_auth(...)
+    """
+    def wrapped(*args, **kwargs):
+        res = defer.maybeDeferred(orig, *args, **kwargs)  # always get a Deferred
+        res.addErrback(_catch_incomplete_interactive_auth)
+        return res
+    return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+    """helper for interactive_auth_handler
+
+    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+    Args:
+        f (failure.Failure): the failure to trap; other error types propagate
+    """
+    f.trap(InteractiveAuthIncompleteError)
+    return 401, f.value.result  # (errcode, body) pair used as the response
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 726e0a2826..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -19,14 +19,15 @@ from twisted.internet import defer
 
 from synapse.api.auth import has_access_token
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.http.servlet import (
     RestServlet, assert_params_in_request,
     parse_json_object_from_request,
 )
 from synapse.util.async import run_on_reactor
 from synapse.util.msisdn import phone_number_to_msisdn
-from ._base import client_v2_patterns
+from synapse.util.threepids import check_3pid_allowed
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             'id_server', 'client_secret', 'email', 'send_attempt'
         ])
 
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
         )
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
 
         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
 
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'msisdn', msisdn
         )
@@ -98,56 +109,61 @@ class PasswordRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.datastore = self.hs.get_datastore()
+        self._set_password_handler = hs.get_set_password_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
-        yield run_on_reactor()
-
         body = parse_json_object_from_request(request)
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-            [LoginType.EMAIL_IDENTITY],
-            [LoginType.MSISDN],
-        ], body, self.hs.get_ip_from_request(request))
+        # there are two possibilities here. Either the user does not have an
+        # access token, and needs to do a password reset; or they have one and
+        # need to validate their identity.
+        #
+        # In the first case, we offer a couple of means of identifying
+        # themselves (email and msisdn, though it's unclear if msisdn actually
+        # works).
+        #
+        # In the second case, we require a password to confirm their identity.
 
-        if not authed:
-            defer.returnValue((401, result))
-
-        user_id = None
-        requester = None
-
-        if LoginType.PASSWORD in result:
-            # if using password, they should also be logged in
+        if has_access_token(request):
             requester = yield self.auth.get_user_by_req(request)
-            user_id = requester.user.to_string()
-            if user_id != result[LoginType.PASSWORD]:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        elif LoginType.EMAIL_IDENTITY in result:
-            threepid = result[LoginType.EMAIL_IDENTITY]
-            if 'medium' not in threepid or 'address' not in threepid:
-                raise SynapseError(500, "Malformed threepid")
-            if threepid['medium'] == 'email':
-                # For emails, transform the address to lowercase.
-                # We store all email addreses as lowercase in the DB.
-                # (See add_threepid in synapse/handlers/auth.py)
-                threepid['address'] = threepid['address'].lower()
-            # if using email, we must know about the email they're authing with!
-            threepid_user_id = yield self.datastore.get_user_id_by_threepid(
-                threepid['medium'], threepid['address']
+            params = yield self.auth_handler.validate_user_via_ui_auth(
+                requester, body, self.hs.get_ip_from_request(request),
             )
-            if not threepid_user_id:
-                raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
-            user_id = threepid_user_id
+            user_id = requester.user.to_string()
         else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
+            requester = None
+            result, params, _ = yield self.auth_handler.check_auth(
+                [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+                body, self.hs.get_ip_from_request(request),
+            )
+
+            if LoginType.EMAIL_IDENTITY in result:
+                threepid = result[LoginType.EMAIL_IDENTITY]
+                if 'medium' not in threepid or 'address' not in threepid:
+                    raise SynapseError(500, "Malformed threepid")
+                if threepid['medium'] == 'email':
+                    # For emails, transform the address to lowercase.
+                    # We store all email addresses as lowercase in the DB.
+                    # (See add_threepid in synapse/handlers/auth.py)
+                    threepid['address'] = threepid['address'].lower()
+                # if using email, we must know about the email they're authing with!
+                threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+                    threepid['medium'], threepid['address']
+                )
+                if not threepid_user_id:
+                    raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+                user_id = threepid_user_id
+            else:
+                logger.error("Auth succeeded but no known type! %r", result.keys())
+                raise SynapseError(500, "", Codes.UNKNOWN)
 
         if 'new_password' not in params:
             raise SynapseError(400, "", Codes.MISSING_PARAM)
         new_password = params['new_password']
 
-        yield self.auth_handler.set_password(
+        yield self._set_password_handler.set_password(
             user_id, new_password, requester
         )
 
@@ -161,52 +177,32 @@ class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/deactivate$")
 
     def __init__(self, hs):
+        super(DeactivateAccountRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
-        super(DeactivateAccountRestServlet, self).__init__()
+        self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        # if the caller provides an access token, it ought to be valid.
-        requester = None
-        if has_access_token(request):
-            requester = yield self.auth.get_user_by_req(
-                request,
-            )  # type: synapse.types.Requester
+        requester = yield self.auth.get_user_by_req(request)
 
         # allow ASes to dectivate their own users
-        if requester and requester.app_service:
-            yield self.auth_handler.deactivate_account(
+        if requester.app_service:
+            yield self._deactivate_account_handler.deactivate_account(
                 requester.user.to_string()
             )
             defer.returnValue((200, {}))
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        if LoginType.PASSWORD in result:
-            user_id = result[LoginType.PASSWORD]
-            # if using password, they should also be logged in
-            if requester is None:
-                raise SynapseError(
-                    400,
-                    "Deactivate account requires an access_token",
-                    errcode=Codes.MISSING_TOKEN
-                )
-            if requester.user.to_string() != user_id:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
-
-        yield self.auth_handler.deactivate_account(user_id)
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
+        yield self._deactivate_account_handler.deactivate_account(
+            requester.user.to_string(),
+        )
         defer.returnValue((200, {}))
 
 
@@ -232,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         if absent:
             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
 
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'email', body['email']
         )
@@ -270,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
 
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.datastore.get_user_id_by_threepid(
             'msisdn', msisdn
         )
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 5321e5abbb..35d58b367a 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,9 +17,9 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api import constants, errors
+from synapse.api import errors
 from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
@@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
         self.device_handler = hs.get_device_handler()
         self.auth_handler = hs.get_auth_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+
         try:
             body = servlet.parse_json_object_from_request(request)
         except errors.SynapseError as e:
@@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
                 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
             )
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [constants.LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
 
-        requester = yield self.auth.get_user_by_req(request)
         yield self.device_handler.delete_devices(
             requester.user.to_string(),
             body['devices'],
@@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
         )
         defer.returnValue((200, device))
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_DELETE(self, request, device_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
             else:
                 raise
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [constants.LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        # check that the UI auth matched the access token
-        user_id = result[constants.LoginType.PASSWORD]
-        if user_id != requester.user.to_string():
-            raise errors.AuthError(403, "Invalid auth")
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
 
-        yield self.device_handler.delete_device(user_id, device_id)
+        yield self.device_handler.delete_device(
+            requester.user.to_string(), device_id,
+        )
         defer.returnValue((200, {}))
 
     @defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 089ec71c81..f762dbfa9a 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -38,7 +38,7 @@ class GroupServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         group_description = yield self.groups_handler.get_group_profile(
@@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         get_group_summary = yield self.groups_handler.get_group_summary(
@@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id, category_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_category(
@@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_categories(
@@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id, role_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_role(
@@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         category = yield self.groups_handler.get_group_roles(
@@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
@@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
@@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         result = yield self.groups_handler.get_publicised_groups_for_user(
             user_id
@@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
         user_ids = content["user_ids"]
@@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
         result = yield self.groups_handler.get_joined_groups(requester_user_id)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 9e2f7308ce..c6f4680a76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,8 +26,9 @@ from synapse.http.servlet import (
     RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
 )
 from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
 
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 import logging
 import hmac
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             'id_server', 'client_secret', 'email', 'send_attempt'
         ])
 
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
         )
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
 
         msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
 
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'msisdn', msisdn
         )
@@ -176,6 +187,7 @@ class RegisterRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
@@ -304,34 +316,66 @@ class RegisterRestServlet(RestServlet):
         if 'x_show_msisdn' in body and body['x_show_msisdn']:
             show_msisdn = True
 
+        # FIXME: need a better error than "no auth flow found" for scenarios
+        # where we required 3PID for registration but the user didn't give one
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            flows = [
-                [LoginType.RECAPTCHA],
-                [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
-            ]
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.RECAPTCHA]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
             if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email:
+                    flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+                # always let users provide both MSISDN & email
                 flows.extend([
-                    [LoginType.MSISDN, LoginType.RECAPTCHA],
                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
                 ])
         else:
-            flows = [
-                [LoginType.DUMMY],
-                [LoginType.EMAIL_IDENTITY],
-            ]
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.DUMMY]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY]])
+
             if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email:
+                    flows.extend([[LoginType.MSISDN]])
+                # always let users provide both MSISDN & email
                 flows.extend([
-                    [LoginType.MSISDN],
-                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
                 ])
 
-        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+        auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
-        if not authed:
-            defer.returnValue((401, auth_result))
-            return
+        # Check that we're not trying to register a denied 3pid.
+        #
+        # the user-facing checks will probably already have happened in
+        # /register/email/requestToken when we requested a 3pid, but that's not
+        # guaranteed.
+
+        if auth_result:
+            for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+                if login_type in auth_result:
+                    medium = auth_result[login_type]['medium']
+                    address = auth_result[login_type]['address']
+
+                    if not check_3pid_allowed(self.hs, medium, address):
+                        raise SynapseError(
+                            403, "Third party identifier is not allowed",
+                            Codes.THREEPID_DENIED,
+                        )
 
         if registered_user_id is not None:
             logger.info(
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index cc2842aa72..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
         self.store = hs.get_datastore()
         self.version_string = hs.version_string
         self.clock = hs.get_clock()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     def render_GET(self, request):
         self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
         logger.info("Handling query for keys %r", query)
         store_queries = []
         for server_name, key_ids in query.items():
+            if (
+                self.federation_domain_whitelist is not None and
+                server_name not in self.federation_domain_whitelist
+            ):
+                logger.debug("Federation denied with %s", server_name)
+                continue
+
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 95fa95fce3..e7ac01da01 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
-        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-        if upload_name:
-            if is_ascii(upload_name):
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename=%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-            else:
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename*=utf-8''%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-
-        # cache for at least a day.
-        # XXX: we might want to turn this off for data we don't want to
-        # recommend caching as it's sensitive or private - or at least
-        # select private. don't bother setting Expires as all our
-        # clients are smart enough to be happy with Cache-Control
-        request.setHeader(
-            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-        )
         if file_size is None:
             stat = os.stat(file_path)
             file_size = stat.st_size
 
-        request.setHeader(
-            b"Content-Length", b"%d" % (file_size,)
-        )
+        add_file_headers(request, media_type, file_size, upload_name)
 
         with open(file_path, "rb") as f:
             yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
         finish_request(request)
     else:
         respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+    """Adds the correct response headers in preparation for responding with the
+    media.
+
+    Args:
+        request (twisted.web.http.Request)
+        media_type (str): The media/content type.
+        file_size (int|None): Size in bytes of the media, or None if not known.
+        upload_name (str|None): The name of the requested file, if any.
+    """
+    request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+    if upload_name:
+        if is_ascii(upload_name):
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename=%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+        else:
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename*=utf-8''%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+
+    # cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select private. don't bother setting Expires as all our
+    # clients are smart enough to be happy with Cache-Control
+    request.setHeader(
+        b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+    )
+
+    # file_size may be None (unknown): omit Content-Length in that case
+    if file_size is not None:
+        request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+    """Responds to the request with given responder. If responder is None then
+    returns 404.
+
+    Args:
+        request (twisted.web.http.Request)
+        responder (Responder|None)
+        media_type (str): The media/content type.
+        file_size (int|None): Size in bytes of the media. If not known it should be None
+        upload_name (str|None): The name of the requested file, if any.
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    add_file_headers(request, media_type, file_size, upload_name)
+    with responder:
+        yield responder.write_to_consumer(request)
+    finish_request(request)
+
+
+class Responder(object):
+    """Represents a response that can be streamed to the requester.
+
+    Responder is a context manager which *must* be used, so that any resources
+    held can be cleaned up.
+    """
+    def write_to_consumer(self, consumer):
+        """Stream response into consumer
+
+        Args:
+            consumer (IConsumer)
+
+        Returns:
+            Deferred: Resolves once the response has finished being written
+        """
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class FileInfo(object):
+    """Details about a requested/uploaded file.
+
+    Attributes:
+        server_name (str): The server name where the media originated from,
+            or None if local.
+        file_id (str): The local ID of the file. For local files this is the
+            same as the media_id
+        url_cache (bool): If the file is for the url preview cache
+        thumbnail (bool): Whether the file is a thumbnail or not.
+        thumbnail_width (int)
+        thumbnail_height (int)
+        thumbnail_method (str)
+        thumbnail_type (str): Content type of thumbnail, e.g. image/png
+    """
+    def __init__(self, server_name, file_id, url_cache=False,
+                 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+                 thumbnail_method=None, thumbnail_type=None):
+        self.server_name = server_name
+        self.file_id = file_id
+        self.url_cache = url_cache
+        self.thumbnail = thumbnail
+        self.thumbnail_width = thumbnail_width
+        self.thumbnail_height = thumbnail_height
+        self.thumbnail_method = thumbnail_method
+        self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import synapse.http.servlet
 
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
 from twisted.web.resource import Resource
 from synapse.http.server import request_handler, set_cors_headers
 
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
     def __init__(self, hs, media_repo):
         Resource.__init__(self)
 
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
-        self.version_string = hs.version_string
+
+        # Both of these are expected by @request_handler()
         self.clock = hs.get_clock()
+        self.version_string = hs.version_string
 
     def render_GET(self, request):
         self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
         )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
-            yield self._respond_local_file(request, media_id, name)
+            yield self.media_repo.get_local_media(request, media_id, name)
         else:
-            yield self._respond_remote_file(
-                request, server_name, media_id, name
-            )
-
-    @defer.inlineCallbacks
-    def _respond_local_file(self, request, media_id, name):
-        media_info = yield self.store.get_local_media(media_id)
-        if not media_info or media_info["quarantined_by"]:
-            respond_404(request)
-            return
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        if media_info["url_cache"]:
-            # TODO: Check the file still exists, if it doesn't we can redownload
-            # it from the url `media_info["url_cache"]`
-            file_path = self.filepaths.url_cache_filepath(media_id)
-        else:
-            file_path = self.filepaths.local_media_filepath(media_id)
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
-
-    @defer.inlineCallbacks
-    def _respond_remote_file(self, request, server_name, media_id, name):
-        # don't forward requests for remote media if allow_remote is false
-        allow_remote = synapse.http.servlet.parse_boolean(
-            request, "allow_remote", default=True)
-        if not allow_remote:
-            logger.info(
-                "Rejecting request for remote media %s/%s due to allow_remote",
-                server_name, media_id,
-            )
-            respond_404(request)
-            return
-
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        filesystem_id = media_info["filesystem_id"]
-        upload_name = name if name else media_info["upload_name"]
-
-        file_path = self.filepaths.remote_media_filepath(
-            server_name, filesystem_id
-        )
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
+            allow_remote = synapse.http.servlet.parse_boolean(
+                request, "allow_remote", default=True)
+            if not allow_remote:
+                logger.info(
+                    "Rejecting request for remote media %s/%s due to allow_remote",
+                    server_name, media_id,
+                )
+                respond_404(request)
+                return
+
+            yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index eed9056a2f..bb79599379 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
 import twisted.web.http
 from twisted.web.resource import Resource
 
+from ._base import respond_404, FileInfo, respond_with_responder
 from .upload_resource import UploadResource
 from .download_resource import DownloadResource
 from .thumbnail_resource import ThumbnailResource
@@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
 from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths
 from .thumbnailer import Thumbnailer
+from .storage_provider import StorageProviderWrapper
+from .media_storage import MediaStorage
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
-    NotFoundError
+from synapse.api.errors import (
+    SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
 
 from synapse.util.async import Linearizer
 from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
 
 import os
@@ -47,7 +52,7 @@ import urlparse
 logger = logging.getLogger(__name__)
 
 
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository(object):
@@ -63,96 +68,62 @@ class MediaRepository(object):
         self.primary_base_path = hs.config.media_store_path
         self.filepaths = MediaFilePaths(self.primary_base_path)
 
-        self.backup_base_path = hs.config.backup_media_store_path
-
-        self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
-
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
         self.recently_accessed_remotes = set()
+        self.recently_accessed_locals = set()
 
-        self.clock.looping_call(
-            self._update_recently_accessed_remotes,
-            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
-        )
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
-    @defer.inlineCallbacks
-    def _update_recently_accessed_remotes(self):
-        media = self.recently_accessed_remotes
-        self.recently_accessed_remotes = set()
-
-        yield self.store.update_cached_last_access_time(
-            media, self.clock.time_msec()
-        )
+        # List of StorageProviders where we should search for media and
+        # potentially upload to.
+        storage_providers = []
 
-    @staticmethod
-    def _makedirs(filepath):
-        dirname = os.path.dirname(filepath)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
+        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+            backend = clz(hs, provider_config)
+            provider = StorageProviderWrapper(
+                backend,
+                store_local=wrapper_config.store_local,
+                store_remote=wrapper_config.store_remote,
+                store_synchronous=wrapper_config.store_synchronous,
+            )
+            storage_providers.append(provider)
 
-    @staticmethod
-    def _write_file_synchronously(source, fname):
-        """Write `source` to the path `fname` synchronously. Should be called
-        from a thread.
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
+        )
 
-        Args:
-            source: A file like object to be written
-            fname (str): Path to write to
-        """
-        MediaRepository._makedirs(fname)
-        source.seek(0)  # Ensure we read from the start of the file
-        with open(fname, "wb") as f:
-            shutil.copyfileobj(source, f)
+        self.clock.looping_call(
+            self._update_recently_accessed,
+            UPDATE_RECENTLY_ACCESSED_TS,
+        )
 
     @defer.inlineCallbacks
-    def write_to_file_and_backup(self, source, path):
-        """Write `source` to the on disk media store, and also the backup store
-        if configured.
-
-        Args:
-            source: A file like object that should be written
-            path (str): Relative path to write file to
-
-        Returns:
-            Deferred[str]: the file path written to in the primary media store
-        """
-        fname = os.path.join(self.primary_base_path, path)
-
-        # Write to the main repository
-        yield make_deferred_yieldable(threads.deferToThread(
-            self._write_file_synchronously, source, fname,
-        ))
+    def _update_recently_accessed(self):
+        remote_media = self.recently_accessed_remotes
+        self.recently_accessed_remotes = set()
 
-        # Write to backup repository
-        yield self.copy_to_backup(path)
+        local_media = self.recently_accessed_locals
+        self.recently_accessed_locals = set()
 
-        defer.returnValue(fname)
+        yield self.store.update_cached_last_access_time(
+            local_media, remote_media, self.clock.time_msec()
+        )
 
-    @defer.inlineCallbacks
-    def copy_to_backup(self, path):
-        """Copy a file from the primary to backup media store, if configured.
+    def mark_recently_accessed(self, server_name, media_id):
+        """Mark the given media as recently accessed.
 
         Args:
-            path(str): Relative path to write file to
+            server_name (str|None): Origin server of media, or None if local
+            media_id (str): The media ID of the content
         """
-        if self.backup_base_path:
-            primary_fname = os.path.join(self.primary_base_path, path)
-            backup_fname = os.path.join(self.backup_base_path, path)
-
-            # We can either wait for successful writing to the backup repository
-            # or write in the background and immediately return
-            if self.synchronous_backup_media_store:
-                yield make_deferred_yieldable(threads.deferToThread(
-                    shutil.copyfile, primary_fname, backup_fname,
-                ))
-            else:
-                preserve_fn(threads.deferToThread)(
-                    shutil.copyfile, primary_fname, backup_fname,
-                )
+        if server_name:
+            self.recently_accessed_remotes.add((server_name, media_id))
+        else:
+            self.recently_accessed_locals.add(media_id)
 
     @defer.inlineCallbacks
     def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +142,13 @@ class MediaRepository(object):
         """
         media_id = random_string(24)
 
-        fname = yield self.write_to_file_and_backup(
-            content, self.filepaths.local_media_filepath_rel(media_id)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=media_id,
         )
 
+        fname = yield self.media_storage.store_file(content, file_info)
+
         logger.info("Stored local media in file %r", fname)
 
         yield self.store.store_local_media(
@@ -185,134 +159,275 @@ class MediaRepository(object):
             media_length=content_length,
             user_id=auth_user,
         )
-        media_info = {
-            "media_type": media_type,
-            "media_length": content_length,
-        }
 
-        yield self._generate_thumbnails(None, media_id, media_info)
+        yield self._generate_thumbnails(
+            None, media_id, media_id, media_type,
+        )
 
         defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
 
     @defer.inlineCallbacks
-    def get_remote_media(self, server_name, media_id):
+    def get_local_media(self, request, media_id, name):
+        """Responds to requests for local media, if exists, or returns 404.
+
+        Args:
+            request(twisted.web.http.Request)
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content.)
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        media_info = yield self.store.get_local_media(media_id)
+        if not media_info or media_info["quarantined_by"]:
+            respond_404(request)
+            return
+
+        self.mark_recently_accessed(None, media_id)
+
+        media_type = media_info["media_type"]
+        media_length = media_info["media_length"]
+        upload_name = name if name else media_info["upload_name"]
+        url_cache = media_info["url_cache"]
+
+        file_info = FileInfo(
+            None, media_id,
+            url_cache=url_cache,
+        )
+
+        responder = yield self.media_storage.fetch_media(file_info)
+        yield respond_with_responder(
+            request, responder, media_type, media_length, upload_name,
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_media(self, request, server_name, media_id, name):
+        """Respond to requests for remote media.
+
+        Args:
+            request(twisted.web.http.Request)
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        self.mark_recently_accessed(server_name, media_id)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
+        key = (server_name, media_id)
+        with (yield self.remote_media_linearizer.queue(key)):
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # We deliberately stream the file outside the lock
+        if responder:
+            media_type = media_info["media_type"]
+            media_length = media_info["media_length"]
+            upload_name = name if name else media_info["upload_name"]
+            yield respond_with_responder(
+                request, responder, media_type, media_length, upload_name,
+            )
+        else:
+            respond_404(request)
+
+    @defer.inlineCallbacks
+    def get_remote_media_info(self, server_name, media_id):
+        """Gets the media info associated with the remote file, downloading
+        if necessary.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[dict]: The media_info of the file
+        """
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
         key = (server_name, media_id)
         with (yield self.remote_media_linearizer.queue(key)):
-            media_info = yield self._get_remote_media_impl(server_name, media_id)
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # Ensure we actually use the responder so that it releases resources
+        if responder:
+            with responder:
+                pass
+
         defer.returnValue(media_info)
 
     @defer.inlineCallbacks
     def _get_remote_media_impl(self, server_name, media_id):
+        """Looks for media in local cache, if not there then attempt to
+        download from remote server.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[(Responder, media_info)]
+        """
         media_info = yield self.store.get_cached_remote_media(
             server_name, media_id
         )
-        if not media_info:
-            media_info = yield self._download_remote_file(
-                server_name, media_id
-            )
-        elif media_info["quarantined_by"]:
-            raise NotFoundError()
+
+        # file_id is the ID we use to track the file locally. If we've already
+        # seen the file then reuse the existing ID, otherwise generate a new
+        # one.
+        if media_info:
+            file_id = media_info["filesystem_id"]
         else:
-            self.recently_accessed_remotes.add((server_name, media_id))
-            yield self.store.update_cached_last_access_time(
-                [(server_name, media_id)], self.clock.time_msec()
-            )
-        defer.returnValue(media_info)
+            file_id = random_string(24)
+
+        file_info = FileInfo(server_name, file_id)
+
+        # If we have an entry in the DB, try and look for it
+        if media_info:
+            if media_info["quarantined_by"]:
+                logger.info("Media is quarantined")
+                raise NotFoundError()
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            if responder:
+                defer.returnValue((responder, media_info))
+
+        # Failed to find the file anywhere, lets download it.
+
+        media_info = yield self._download_remote_file(
+            server_name, media_id, file_id
+        )
+
+        responder = yield self.media_storage.fetch_media(file_info)
+        defer.returnValue((responder, media_info))
 
     @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id):
-        file_id = random_string(24)
+    def _download_remote_file(self, server_name, media_id, file_id):
+        """Attempt to download the remote file from the given server name,
+        using the given file_id as the local id.
+
+        Args:
+            server_name (str): Originating server
+            media_id (str): The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            file_id (str): Local file ID
+
+        Returns:
+            Deferred[MediaInfo]
+        """
 
-        fpath = self.filepaths.remote_media_filepath_rel(
-            server_name, file_id
+        file_info = FileInfo(
+            server_name=server_name,
+            file_id=file_id,
         )
-        fname = os.path.join(self.primary_base_path, fpath)
-        self._makedirs(fname)
 
-        try:
-            with open(fname, "wb") as f:
-                request_path = "/".join((
-                    "/_matrix/media/v1/download", server_name, media_id,
-                ))
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            request_path = "/".join((
+                "/_matrix/media/v1/download", server_name, media_id,
+            ))
+            try:
+                length, headers = yield self.client.get_file(
+                    server_name, request_path, output_stream=f,
+                    max_size=self.max_upload_size, args={
+                        # tell the remote server to 404 if it doesn't
+                        # recognise the server_name, to make sure we don't
+                        # end up with a routing loop.
+                        "allow_remote": "false",
+                    }
+                )
+            except twisted.internet.error.DNSLookupError as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %r",
+                            server_name, media_id, e)
+                raise NotFoundError()
+
+            except HttpResponseException as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %s",
+                            server_name, media_id, e.response)
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise SynapseError.from_http_response_exception(e)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise
+            except NotRetryingDestination:
+                logger.warn("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            yield finish()
+
+        media_type = headers["Content-Type"][0]
+
+        time_now_ms = self.clock.time_msec()
+
+        content_disposition = headers.get("Content-Disposition", None)
+        if content_disposition:
+            _, params = cgi.parse_header(content_disposition[0],)
+            upload_name = None
+
+            # First check if there is a valid UTF-8 filename
+            upload_name_utf8 = params.get("filename*", None)
+            if upload_name_utf8:
+                if upload_name_utf8.lower().startswith("utf-8''"):
+                    upload_name = upload_name_utf8[7:]
+
+            # If there isn't check for an ascii name.
+            if not upload_name:
+                upload_name_ascii = params.get("filename", None)
+                if upload_name_ascii and is_ascii(upload_name_ascii):
+                    upload_name = upload_name_ascii
+
+            if upload_name:
+                upload_name = urlparse.unquote(upload_name)
                 try:
-                    length, headers = yield self.client.get_file(
-                        server_name, request_path, output_stream=f,
-                        max_size=self.max_upload_size, args={
-                            # tell the remote server to 404 if it doesn't
-                            # recognise the server_name, to make sure we don't
-                            # end up with a routing loop.
-                            "allow_remote": "false",
-                        }
-                    )
-                except twisted.internet.error.DNSLookupError as e:
-                    logger.warn("HTTP error fetching remote media %s/%s: %r",
-                                server_name, media_id, e)
-                    raise NotFoundError()
-
-                except HttpResponseException as e:
-                    logger.warn("HTTP error fetching remote media %s/%s: %s",
-                                server_name, media_id, e.response)
-                    if e.code == twisted.web.http.NOT_FOUND:
-                        raise SynapseError.from_http_response_exception(e)
-                    raise SynapseError(502, "Failed to fetch remote media")
-
-                except SynapseError:
-                    logger.exception("Failed to fetch remote media %s/%s",
-                                     server_name, media_id)
-                    raise
-                except NotRetryingDestination:
-                    logger.warn("Not retrying destination %r", server_name)
-                    raise SynapseError(502, "Failed to fetch remote media")
-                except Exception:
-                    logger.exception("Failed to fetch remote media %s/%s",
-                                     server_name, media_id)
-                    raise SynapseError(502, "Failed to fetch remote media")
-
-            yield self.copy_to_backup(fpath)
-
-            media_type = headers["Content-Type"][0]
-            time_now_ms = self.clock.time_msec()
-
-            content_disposition = headers.get("Content-Disposition", None)
-            if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
-                upload_name = None
-
-                # First check if there is a valid UTF-8 filename
-                upload_name_utf8 = params.get("filename*", None)
-                if upload_name_utf8:
-                    if upload_name_utf8.lower().startswith("utf-8''"):
-                        upload_name = upload_name_utf8[7:]
-
-                # If there isn't check for an ascii name.
-                if not upload_name:
-                    upload_name_ascii = params.get("filename", None)
-                    if upload_name_ascii and is_ascii(upload_name_ascii):
-                        upload_name = upload_name_ascii
-
-                if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
-                    try:
-                        upload_name = upload_name.decode("utf-8")
-                    except UnicodeDecodeError:
-                        upload_name = None
-            else:
-                upload_name = None
-
-            logger.info("Stored remote media in file %r", fname)
-
-            yield self.store.store_cached_remote_media(
-                origin=server_name,
-                media_id=media_id,
-                media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
-                upload_name=upload_name,
-                media_length=length,
-                filesystem_id=file_id,
-            )
-        except Exception:
-            os.remove(fname)
-            raise
+                    upload_name = upload_name.decode("utf-8")
+                except UnicodeDecodeError:
+                    upload_name = None
+        else:
+            upload_name = None
+
+        logger.info("Stored remote media in file %r", fname)
+
+        yield self.store.store_cached_remote_media(
+            origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=length,
+            filesystem_id=file_id,
+        )
 
         media_info = {
             "media_type": media_type,
@@ -323,7 +438,7 @@ class MediaRepository(object):
         }
 
         yield self._generate_thumbnails(
-            server_name, media_id, media_info
+            server_name, media_id, file_id, media_type,
         )
 
         defer.returnValue(media_info)
@@ -357,8 +472,10 @@ class MediaRepository(object):
 
     @defer.inlineCallbacks
     def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                       t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
+                                       t_method, t_type, url_cache):
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            None, media_id, url_cache=url_cache,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -368,11 +485,19 @@ class MediaRepository(object):
 
         if t_byte_source:
             try:
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source,
-                    self.filepaths.local_media_thumbnail_rel(
-                        media_id, t_width, t_height, t_type, t_method
-                    )
+                file_info = FileInfo(
+                    server_name=None,
+                    file_id=media_id,
+                    url_cache=url_cache,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -390,7 +515,9 @@ class MediaRepository(object):
     @defer.inlineCallbacks
     def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=False,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -400,11 +527,18 @@ class MediaRepository(object):
 
         if t_byte_source:
             try:
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source,
-                    self.filepaths.remote_media_thumbnail_rel(
-                        server_name, file_id, t_width, t_height, t_type, t_method
-                    )
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,  # local filesystem ID, not the remote media_id
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -421,31 +555,29 @@ class MediaRepository(object):
             defer.returnValue(output_path)
 
     @defer.inlineCallbacks
-    def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+                             url_cache=False):
         """Generate and store thumbnails for an image.
 
         Args:
-            server_name(str|None): The server name if remote media, else None if local
-            media_id(str)
-            media_info(dict)
-            url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+            server_name (str|None): The server name if remote media, else None if local
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content)
+            file_id (str): Local file ID
+            media_type (str): The content type of the file
+            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
                 used exclusively by the url previewer
 
         Returns:
             Deferred[dict]: Dict with "width" and "height" keys of original image
         """
-        media_type = media_info["media_type"]
-        file_id = media_info.get("filesystem_id")
         requirements = self._get_thumbnail_requirements(media_type)
         if not requirements:
             return
 
-        if server_name:
-            input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        elif url_cache:
-            input_path = self.filepaths.url_cache_filepath(media_id)
-        else:
-            input_path = self.filepaths.local_media_filepath(media_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=url_cache,
+        ))
 
         thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width
@@ -472,20 +604,6 @@ class MediaRepository(object):
 
         # Now we generate the thumbnails for each dimension, store it
         for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
-            # Work out the correct file name for thumbnail
-            if server_name:
-                file_path = self.filepaths.remote_media_thumbnail_rel(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-            elif url_cache:
-                file_path = self.filepaths.url_cache_thumbnail_rel(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-            else:
-                file_path = self.filepaths.local_media_thumbnail_rel(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-
             # Generate the thumbnail
             if t_method == "crop":
                 t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +623,19 @@ class MediaRepository(object):
                 continue
 
             try:
-                # Write to disk
-                output_path = yield self.write_to_file_and_backup(
-                    t_byte_source, file_path,
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                    url_cache=url_cache,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
                 )
             finally:
                 t_byte_source.close()
@@ -620,7 +748,11 @@ class MediaRepositoryResource(Resource):
 
         self.putChild("upload", UploadResource(hs, media_repo))
         self.putChild("download", DownloadResource(hs, media_repo))
-        self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+        self.putChild("thumbnail", ThumbnailResource(
+            hs, media_repo, media_repo.media_storage,
+        ))
         self.putChild("identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+            self.putChild("preview_url", PreviewUrlResource(
+                hs, media_repo, media_repo.media_storage,
+            ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..3f8d4b9c22
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,274 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from ._base import Responder
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+    """Responsible for storing/fetching files from local sources.
+
+    Args:
+        local_media_directory (str): Base path where we store media on disk
+        filepaths (MediaFilePaths)
+        storage_providers ([StorageProvider]): List of StorageProvider that are
+            used to fetch and store files.
+    """
+
+    def __init__(self, local_media_directory, filepaths, storage_providers):
+        self.local_media_directory = local_media_directory
+        self.filepaths = filepaths
+        self.storage_providers = storage_providers
+
+    @defer.inlineCallbacks
+    def store_file(self, source, file_info):
+        """Write `source` to the on disk media store, and also any other
+        configured storage providers
+
+        Args:
+            source: A file like object that should be written
+            file_info (FileInfo): Info about the file to store
+
+        Returns:
+            Deferred[str]: the file path written to in the primary media store
+        """
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        # Write to the main repository
+        yield make_deferred_yieldable(threads.deferToThread(
+            _write_file_synchronously, source, fname,
+        ))
+
+        # Tell the storage providers about the new file. They'll decide
+        # if they should upload it and whether to do so synchronously
+        # or not.
+        for provider in self.storage_providers:
+            yield provider.store_file(path, file_info)
+
+        defer.returnValue(fname)
+
+    @contextlib.contextmanager
+    def store_into_file(self, file_info):
+        """Context manager used to get a file like object to write into, as
+        described by file_info.
+
+        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+        like object that can be written to, fname is the absolute path of file
+        on disk, and finish_cb is a function that returns a Deferred.
+
+        fname can be used to read the contents from after upload, e.g. to
+        generate thumbnails.
+
+        finish_cb must be called and waited on after the file has been
+        successfully written to. Should not be called if there was an
+        error.
+
+        Args:
+            file_info (FileInfo): Info about the file to store
+
+        Example:
+
+            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+                # .. write into f ...
+                yield finish_cb()
+        """
+
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        finished_called = [False]  # a list, so that finish() can mutate it
+
+        @defer.inlineCallbacks
+        def finish():
+            for provider in self.storage_providers:
+                yield provider.store_file(path, file_info)
+
+            finished_called[0] = True
+
+        try:
+            with open(fname, "wb") as f:
+                yield f, fname, finish
+        except Exception:
+            t, v, tb = sys.exc_info()
+            try:
+                os.remove(fname)
+            except Exception:
+                pass
+            raise t, v, tb
+
+        if not finished_called[0]:
+            raise Exception("Finished callback not called")
+
+    @defer.inlineCallbacks
+    def fetch_media(self, file_info):
+        """Attempts to fetch media described by file_info from the local cache
+        and configured storage providers.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[Responder|None]: Returns a Responder if the file was found,
+                otherwise None.
+        """
+
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(FileResponder(open(local_path, "rb")))
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                defer.returnValue(res)
+
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def ensure_media_is_in_local_cache(self, file_info):
+        """Ensures that the given file is in the local cache. Attempts to
+        download it from storage providers if it isn't.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[str]: Full path to local file
+        """
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(local_path)
+
+        dirname = os.path.dirname(local_path)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                with res:
+                    consumer = BackgroundFileConsumer(open(local_path, "wb"))
+                    yield res.write_to_consumer(consumer)
+                    yield consumer.wait()
+                defer.returnValue(local_path)
+
+        raise Exception("file could not be found")
+
+    def _file_info_to_path(self, file_info):
+        """Converts file_info into a relative path.
+
+        The path is suitable for storing files under a directory, e.g. used to
+        store files on local FS under the base media repository directory.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            str
+        """
+        if file_info.url_cache:
+            if file_info.thumbnail:
+                return self.filepaths.url_cache_thumbnail_rel(
+                    media_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method,
+                )
+            return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+        if file_info.server_name:
+            if file_info.thumbnail:
+                return self.filepaths.remote_media_thumbnail_rel(
+                    server_name=file_info.server_name,
+                    file_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method
+                )
+            return self.filepaths.remote_media_filepath_rel(
+                file_info.server_name, file_info.file_id,
+            )
+
+        if file_info.thumbnail:
+            return self.filepaths.local_media_thumbnail_rel(
+                media_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+                method=file_info.thumbnail_method
+            )
+        return self.filepaths.local_media_filepath_rel(
+            file_info.file_id,
+        )
+
+
+def _write_file_synchronously(source, fname):
+    """Copy the contents of `source` into the file at `fname`, blocking
+    until done. Intended to be run in a thread.
+
+    Args:
+        source: A file like object whose contents are copied
+        fname (str): Destination path to write to
+    """
+    target_dir = os.path.dirname(fname)
+    if not os.path.exists(target_dir):
+        os.makedirs(target_dir)
+
+    source.seek(0)  # Rewind so the copy starts at the first byte
+    with open(fname, "wb") as dest:
+        shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+    """Wraps an open file that can be sent to a request.
+
+    Args:
+        open_file (file): A file like object to be streamed to the client,
+            is closed when finished streaming.
+    """
+    def __init__(self, open_file):
+        self.open_file = open_file
+
+    def write_to_consumer(self, consumer):
+        return FileSender().beginFileTransfer(self.open_file, consumer)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 723f7043f4..31fe7aa75c 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,11 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import ujson as json
+import urlparse
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
 from twisted.web.resource import Resource
 
+from ._base import FileInfo
+
 from synapse.api.errors import (
     SynapseError, Codes,
 )
@@ -25,30 +40,19 @@ from synapse.util.stringutils import random_string
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.http.client import SpiderHttpClient
 from synapse.http.server import (
-    request_handler, respond_with_json_bytes
+    request_handler, respond_with_json_bytes,
+    respond_with_json,
 )
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii
 
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-
-import logging
 logger = logging.getLogger(__name__)
 
 
 class PreviewUrlResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.auth = hs.get_auth()
@@ -61,6 +65,7 @@ class PreviewUrlResource(Resource):
         self.client = SpiderHttpClient(hs)
         self.media_repo = media_repo
         self.primary_base_path = media_repo.primary_base_path
+        self.media_storage = media_storage
 
         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
@@ -78,6 +83,9 @@ class PreviewUrlResource(Resource):
             self._expire_url_cache_data, 10 * 1000
         )
 
+    def render_OPTIONS(self, request):
+        return respond_with_json(request, 200, {}, send_cors=True)
+
     def render_GET(self, request):
         self._async_render_GET(request)
         return NOT_DONE_YET
@@ -178,8 +186,10 @@ class PreviewUrlResource(Resource):
         logger.debug("got media_info of '%s'" % media_info)
 
         if _is_media(media_info['media_type']):
+            file_id = media_info['filesystem_id']
             dims = yield self.media_repo._generate_thumbnails(
-                None, media_info['filesystem_id'], media_info, url_cache=True,
+                None, file_id, file_id, media_info["media_type"],
+                url_cache=True,
             )
 
             og = {
@@ -224,8 +234,10 @@ class PreviewUrlResource(Resource):
 
                 if _is_media(image_info['media_type']):
                     # TODO: make sure we don't choke on white-on-transparent images
+                    file_id = image_info['filesystem_id']
                     dims = yield self.media_repo._generate_thumbnails(
-                        None, image_info['filesystem_id'], image_info, url_cache=True,
+                        None, file_id, file_id, image_info["media_type"],
+                        url_cache=True,
                     )
                     if dims:
                         og["og:image:width"] = dims['width']
@@ -269,21 +281,34 @@ class PreviewUrlResource(Resource):
 
         file_id = datetime.date.today().isoformat() + '_' + random_string(16)
 
-        fpath = self.filepaths.url_cache_filepath_rel(file_id)
-        fname = os.path.join(self.primary_base_path, fpath)
-        self.media_repo._makedirs(fname)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=file_id,
+            url_cache=True,
+        )
 
-        try:
-            with open(fname, "wb") as f:
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            try:
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
+            except Exception as e:
                 # FIXME: pass through 404s and other error messages nicely
+                logger.warn("Error downloading %s: %r", url, e)
+                raise SynapseError(
+                    500, "Failed to download content: %s" % (
+                        traceback.format_exception_only(sys.exc_type, e),
+                    ),
+                    Codes.UNKNOWN,
+                )
+            yield finish()
 
-            yield self.media_repo.copy_to_backup(fpath)
-
-            media_type = headers["Content-Type"][0]
+        try:
+            if "Content-Type" in headers:
+                media_type = headers["Content-Type"][0]
+            else:
+                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()
 
             content_disposition = headers.get("Content-Disposition", None)
@@ -323,11 +348,11 @@ class PreviewUrlResource(Resource):
             )
 
         except Exception as e:
-            os.remove(fname)
-            raise SynapseError(
-                500, ("Failed to download content: %s" % e),
-                Codes.UNKNOWN
-            )
+            logger.error("Error handling downloaded %s: %r", url, e)
+            # TODO: we really ought to delete the downloaded file in this
+            # case, since we won't have recorded it in the db, and will
+            # therefore not expire it.
+            raise
 
         defer.returnValue({
             "media_type": media_type,
@@ -348,11 +373,16 @@ class PreviewUrlResource(Resource):
     def _expire_url_cache_data(self):
         """Clean up expired url cache content, media and thumbnails.
         """
-
         # TODO: Delete from backup media store
 
         now = self.clock.time_msec()
 
+        logger.info("Running url preview cache expiry")
+
+        if not (yield self.store.has_completed_background_updates()):
+            logger.info("Still running DB updates; skipping expiry")
+            return
+
         # First we delete expired url cache entries
         media_ids = yield self.store.get_expired_url_cache(now)
 
@@ -426,8 +456,7 @@ class PreviewUrlResource(Resource):
 
         yield self.store.delete_url_cache_media(removed_media)
 
-        if removed_media:
-            logger.info("Deleted %d media from url cache", len(removed_media))
+        logger.info("Deleted %d media from url cache", len(removed_media))
 
 
 def decode_and_calc_og(body, media_uri, request_encoding=None):
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..c188192f2b
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.config._base import Config
+from synapse.util.logcontext import preserve_fn
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+    """A storage provider is a service that can store uploaded media and
+    retrieve them.
+    """
+    def store_file(self, path, file_info):
+        """Store the file described by file_info. The actual contents can be
+        retrieved by reading the file in file_info.upload_path.
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred
+        """
+        pass
+
+    def fetch(self, path, file_info):
+        """Attempt to fetch the file described by file_info, handing it back
+        as a Responder (see Returns).
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred(Responder): Returns a Responder if the provider has the file,
+                otherwise returns None.
+        """
+        pass
+
+
+class StorageProviderWrapper(StorageProvider):
+    """Wraps a storage provider and provides various config options
+
+    Args:
+        backend (StorageProvider)
+        store_local (bool): Whether to store new local files or not.
+        store_synchronous (bool): Whether to wait for file to be successfully
+            uploaded, or to do the upload in the background.
+        store_remote (bool): Whether remote media should be uploaded
+    """
+    def __init__(self, backend, store_local, store_synchronous, store_remote):
+        self.backend = backend
+        self.store_local = store_local
+        self.store_synchronous = store_synchronous
+        self.store_remote = store_remote
+
+    def store_file(self, path, file_info):
+        if not file_info.server_name and not self.store_local:
+            return defer.succeed(None)
+
+        if file_info.server_name and not self.store_remote:
+            return defer.succeed(None)
+
+        if self.store_synchronous:
+            return self.backend.store_file(path, file_info)
+        else:
+            # TODO: Handle errors.
+            preserve_fn(self.backend.store_file)(path, file_info)
+            return defer.succeed(None)
+
+    def fetch(self, path, file_info):
+        return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+    """A storage provider that stores files in a directory on a filesystem.
+
+    Args:
+        hs (HomeServer)
+        config: The value returned by `parse_config` (the base directory).
+    """
+
+    def __init__(self, hs, config):
+        self.cache_directory = hs.config.media_store_path
+        self.base_directory = config
+
+    def store_file(self, path, file_info):
+        """See StorageProvider.store_file"""
+
+        primary_fname = os.path.join(self.cache_directory, path)
+        backup_fname = os.path.join(self.base_directory, path)
+
+        dirname = os.path.dirname(backup_fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        return threads.deferToThread(
+            shutil.copyfile, primary_fname, backup_fname,
+        )
+
+    def fetch(self, path, file_info):
+        """See StorageProvider.fetch"""
+
+        backup_fname = os.path.join(self.base_directory, path)
+        if os.path.isfile(backup_fname):
+            return FileResponder(open(backup_fname, "rb"))
+
+    @staticmethod
+    def parse_config(config):
+        """Called on startup to parse config supplied. This should parse
+        the config and raise if there is a problem.
+
+        The returned value is passed into the constructor.
+
+        In this case we only care about a single param, the directory, so let's
+        just pull that out.
+        """
+        return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 68d56b2b10..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 
 
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+    parse_media_id, respond_404, respond_with_file, FileInfo,
+    respond_with_responder,
+)
 from twisted.web.resource import Resource
 from synapse.http.servlet import parse_string, parse_integer
 from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
 class ThumbnailResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.store = hs.get_datastore()
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
+        self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
         self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
                 yield self._respond_local_thumbnail(
                     request, media_id, width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(None, media_id)
         else:
             if self.dynamic_thumbnails:
                 yield self._select_or_generate_remote_thumbnail(
@@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
                     request, server_name, media_id,
                     width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(server_name, media_id)
 
     @defer.inlineCallbacks
     def _respond_local_thumbnail(self, request, media_id, width, height,
                                  method, m_type):
         media_info = yield self.store.get_local_media(media_id)
 
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info:
+            respond_404(request)
+            return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
             respond_404(request)
             return
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
 
@@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
-
-            if media_info["url_cache"]:
-                # TODO: Check the file still exists, if it doesn't we can redownload
-                # it from the url `media_info["url_cache"]`
-                file_path = self.filepaths.url_cache_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method,
-                )
-            else:
-                file_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method,
-                )
-            yield respond_with_file(request, t_type, file_path)
 
-        else:
-            yield self._respond_default_thumbnail(
-                request, media_info, width, height, method, m_type,
+            file_info = FileInfo(
+                server_name=None, file_id=media_id,
+                url_cache=media_info["url_cache"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
             )
 
+            t_type = file_info.thumbnail_type
+            t_length = thumbnail_info["thumbnail_length"]
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
+        else:
+            logger.info("Couldn't find any generated thumbnails")
+            respond_404(request)
+
     @defer.inlineCallbacks
     def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
                                             desired_height, desired_method,
                                             desired_type):
         media_info = yield self.store.get_local_media(media_id)
 
-        if not media_info or media_info["quarantined_by"]:
+        if not media_info:
+            respond_404(request)
+            return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
             respond_404(request)
             return
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
         for info in thumbnail_infos:
@@ -141,46 +142,43 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                if media_info["url_cache"]:
-                    # TODO: Check the file still exists, if it doesn't we can redownload
-                    # it from the url `media_info["url_cache"]`
-                    file_path = self.filepaths.url_cache_thumbnail(
-                        media_id, desired_width, desired_height, desired_type,
-                        desired_method,
-                    )
-                else:
-                    file_path = self.filepaths.local_media_thumbnail(
-                        media_id, desired_width, desired_height, desired_type,
-                        desired_method,
-                    )
-                yield respond_with_file(request, desired_type, file_path)
-                return
-
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                file_info = FileInfo(
+                    server_name=None, file_id=media_id,
+                    url_cache=media_info["url_cache"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
+                )
+
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_local_exact_thumbnail(
-            media_id, desired_width, desired_height, desired_method, desired_type
+            media_id, desired_width, desired_height, desired_method, desired_type,
+            url_cache=media_info["url_cache"],
         )
 
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, desired_width, desired_height,
-                desired_method, desired_type,
-            )
+            logger.warn("Failed to generate thumbnail")
+            respond_404(request)
 
     @defer.inlineCallbacks
     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
                                              desired_width, desired_height,
                                              desired_method, desired_type):
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -195,14 +193,24 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                file_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, desired_width, desired_height,
-                    desired_type, desired_method,
+                file_info = FileInfo(
+                    server_name=server_name, file_id=media_info["filesystem_id"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
                 )
-                yield respond_with_file(request, desired_type, file_path)
-                return
 
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -213,22 +221,16 @@ class ThumbnailResource(Resource):
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, desired_width, desired_height,
-                desired_method, desired_type,
-            )
+            logger.warn("Failed to generate thumbnail")
+            respond_404(request)
 
     @defer.inlineCallbacks
     def _respond_remote_thumbnail(self, request, server_name, media_id, width,
                                   height, method, m_type):
         # TODO: Don't download the whole remote file
-        # We should proxy the thumbnail from the remote server instead.
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        # We should proxy the thumbnail from the remote server instead of
+        # downloading the remote file and generating our own thumbnails.
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -238,59 +240,23 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
-            file_id = thumbnail_info["filesystem_id"]
+            file_info = FileInfo(
+                server_name=server_name, file_id=media_info["filesystem_id"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+            )
+
+            t_type = file_info.thumbnail_type
             t_length = thumbnail_info["thumbnail_length"]
 
-            file_path = self.filepaths.remote_media_thumbnail(
-                server_name, file_id, t_width, t_height, t_type, t_method,
-            )
-            yield respond_with_file(request, t_type, file_path, t_length)
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, width, height, method, m_type,
-            )
-
-    @defer.inlineCallbacks
-    def _respond_default_thumbnail(self, request, media_info, width, height,
-                                   method, m_type):
-        # XXX: how is this meant to work? store.get_default_thumbnails
-        # appears to always return [] so won't this always 404?
-        media_type = media_info["media_type"]
-        top_level_type = media_type.split("/")[0]
-        sub_type = media_type.split("/")[-1].split(";")[0]
-        thumbnail_infos = yield self.store.get_default_thumbnails(
-            top_level_type, sub_type,
-        )
-        if not thumbnail_infos:
-            thumbnail_infos = yield self.store.get_default_thumbnails(
-                top_level_type, "_default",
-            )
-        if not thumbnail_infos:
-            thumbnail_infos = yield self.store.get_default_thumbnails(
-                "_default", "_default",
-            )
-        if not thumbnail_infos:
+            logger.info("Failed to find any generated thumbnails")
             respond_404(request)
-            return
-
-        thumbnail_info = self._select_thumbnail(
-            width, height, "crop", m_type, thumbnail_infos
-        )
-
-        t_width = thumbnail_info["thumbnail_width"]
-        t_height = thumbnail_info["thumbnail_height"]
-        t_type = thumbnail_info["thumbnail_type"]
-        t_method = thumbnail_info["thumbnail_method"]
-        t_length = thumbnail_info["thumbnail_length"]
-
-        file_path = self.filepaths.default_thumbnail(
-            top_level_type, sub_type, t_width, t_height, t_type, t_method,
-        )
-        yield respond_with_file(request, t_type, file_path, t_length)
 
     def _select_thumbnail(self, desired_width, desired_height, desired_method,
                           desired_type, thumbnail_infos):
diff --git a/synapse/server.py b/synapse/server.py
index 10e3e9a4f1..fbd602d40e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -39,20 +39,23 @@ from synapse.federation.transaction_queue import TransactionQueue
 from synapse.handlers import Handlers
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
+from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.room_list import RoomListHandler
+from synapse.handlers.set_password import SetPasswordHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
 from synapse.handlers.events import EventHandler, EventStreamHandler
 from synapse.handlers.initial_sync import InitialSyncHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
-from synapse.handlers.user_directory import UserDirectoyHandler
+from synapse.handlers.user_directory import UserDirectoryHandler
 from synapse.handlers.groups_local import GroupsLocalHandler
 from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
 from synapse.groups.groups_server import GroupsServerHandler
 from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -60,8 +63,11 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
-from synapse.rest.media.v1.media_repository import MediaRepository
-from synapse.state import StateHandler
+from synapse.rest.media.v1.media_repository import (
+    MediaRepository,
+    MediaRepositoryResource,
+)
+from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStore
 from synapse.streams.events import EventSources
 from synapse.util import Clock
@@ -90,18 +96,14 @@ class HomeServer(object):
     """
 
     DEPENDENCIES = [
-        'config',
-        'clock',
         'http_client',
         'db_pool',
-        'persistence_service',
         'replication_layer',
-        'datastore',
         'handlers',
         'v1auth',
         'auth',
-        'rest_servlet_factory',
         'state_handler',
+        'state_resolution_handler',
         'presence_handler',
         'sync_handler',
         'typing_handler',
@@ -117,19 +119,11 @@ class HomeServer(object):
         'application_service_handler',
         'device_message_handler',
         'profile_handler',
+        'event_creation_handler',
+        'deactivate_account_handler',
+        'set_password_handler',
         'notifier',
-        'distributor',
-        'client_resource',
-        'resource_for_federation',
-        'resource_for_static_content',
-        'resource_for_web_client',
-        'resource_for_content_repo',
-        'resource_for_server_key',
-        'resource_for_server_key_v2',
-        'resource_for_media_repository',
-        'resource_for_metrics',
         'event_sources',
-        'ratelimiter',
         'keyring',
         'pusherpool',
         'event_builder_factory',
@@ -137,6 +131,7 @@ class HomeServer(object):
         'http_client_context_factory',
         'simple_http_client',
         'media_repository',
+        'media_repository_resource',
         'federation_transport_client',
         'federation_sender',
         'receipts_handler',
@@ -183,6 +178,21 @@ class HomeServer(object):
     def is_mine_id(self, string):
         return string.split(":", 1)[1] == self.hostname
 
+    def get_clock(self):
+        return self.clock
+
+    def get_datastore(self):
+        return self.datastore
+
+    def get_config(self):
+        return self.config
+
+    def get_distributor(self):
+        return self.distributor
+
+    def get_ratelimiter(self):
+        return self.ratelimiter
+
     def build_replication_layer(self):
         return initialize_http_replication(self)
 
@@ -217,6 +227,9 @@ class HomeServer(object):
     def build_state_handler(self):
         return StateHandler(self)
 
+    def build_state_resolution_handler(self):
+        return StateResolutionHandler(self)
+
     def build_presence_handler(self):
         return PresenceHandler(self)
 
@@ -265,6 +278,15 @@ class HomeServer(object):
     def build_profile_handler(self):
         return ProfileHandler(self)
 
+    def build_event_creation_handler(self):
+        return EventCreationHandler(self)
+
+    def build_deactivate_account_handler(self):
+        return DeactivateAccountHandler(self)
+
+    def build_set_password_handler(self):
+        return SetPasswordHandler(self)
+
     def build_event_sources(self):
         return EventSources(self)
 
@@ -294,6 +316,28 @@ class HomeServer(object):
             **self.db_config.get("args", {})
         )
 
+    def get_db_conn(self, run_new_connection=True):
+        """Makes a new connection to the database, skipping the db pool
+
+        Returns:
+            Connection: a connection object implementing the PEP-249 spec
+        """
+        # Any param beginning with cp_ is a parameter for adbapi, and should
+        # not be passed to the database engine.
+        db_params = {
+            k: v for k, v in self.db_config.get("args", {}).items()
+            if not k.startswith("cp_")
+        }
+        db_conn = self.database_engine.module.connect(**db_params)
+        if run_new_connection:
+            self.database_engine.on_new_connection(db_conn)
+        return db_conn
+
+    def build_media_repository_resource(self):
+        # Build the media repo resource. This indirects through the HomeServer
+        # to ensure that we only have a single instance of it.
+        return MediaRepositoryResource(self)
+
     def build_media_repository(self):
         return MediaRepository(self)
 
@@ -321,7 +365,7 @@ class HomeServer(object):
         return ActionGenerator(self)
 
     def build_user_directory_handler(self):
-        return UserDirectoyHandler(self)
+        return UserDirectoryHandler(self)
 
     def build_groups_local_handler(self):
         return GroupsLocalHandler(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index e8c0386b7f..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -3,10 +3,14 @@ import synapse.federation.transaction_queue
 import synapse.federation.transport.client
 import synapse.handlers
 import synapse.handlers.auth
+import synapse.handlers.deactivate_account
 import synapse.handlers.device
 import synapse.handlers.e2e_keys
-import synapse.storage
+import synapse.handlers.set_password
+import synapse.rest.media.v1.media_repository
 import synapse.state
+import synapse.storage
+
 
 class HomeServer(object):
     def get_auth(self) -> synapse.api.auth.Auth:
@@ -30,8 +34,23 @@ class HomeServer(object):
     def get_state_handler(self) -> synapse.state.StateHandler:
         pass
 
+    def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+        pass
+
+    def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+        pass
+
+    def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+        pass
+
     def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
         pass
 
     def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
         pass
+
+    def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+        pass
+
+    def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+        pass
diff --git a/synapse/state.py b/synapse/state.py
index 9e624b4937..cc93bbcb6b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        # dict[(str, str), str]: map from (type, state_key) to event_id
         self.state = frozendict(state)
+
+        # The ID of a state group if one and only one is involved.
+        # Otherwise, presumably None.
         self.state_group = state_group
 
         self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
 
 
 class StateHandler(object):
-    """ Responsible for doing state conflict resolution.
+    """Fetches bits of state from the stores, and does state resolution
+    where necessary
     """
 
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.hs = hs
-
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
-        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+        self._state_resolution_handler = hs.get_state_resolution_handler()
 
     def start_caching(self):
-        logger.debug("start_caching")
-
-        self._state_cache = ExpiringCache(
-            cache_name="state_cache",
-            clock=self.clock,
-            max_len=SIZE_OF_CACHE,
-            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
-            iterable=True,
-            reset_expiry_on_get=True,
-        )
-
-        self._state_cache.start()
+        # TODO: remove this shim
+        self._state_resolution_handler.start_caching()
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         if event_type:
@@ -146,19 +138,27 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, room_id, event_type=None, state_key="",
-                              latest_event_ids=None):
+    def get_current_state_ids(self, room_id, latest_event_ids=None):
+        """Get the current state, or the state at a set of events, for a room
+
+        Args:
+            room_id (str):
+
+            latest_event_ids (iterable[str]|None): if given, the forward
+                extremities to resolve. If None, we look them up from the
+                database (via a cache)
+
+        Returns:
+            Deferred[dict[(str, str), str]]: the state dict, mapping from
+                (event_type, state_key) -> event_id
+        """
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
-        if event_type:
-            defer.returnValue(state.get((event_type, state_key)))
-            return
-
         defer.returnValue(state)
 
     @defer.inlineCallbacks
@@ -166,7 +166,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_user_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
         defer.returnValue(joined_users)
 
@@ -175,7 +175,7 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
         defer.returnValue(joined_hosts)
 
@@ -183,8 +183,15 @@ class StateHandler(object):
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.
 
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
+
         Args:
             event (synapse.events.EventBase):
+            old_state (iterable[EventBase]|None): The state events at the
+                event, if the state can't be calculated from existing events.
+                This is normally only specified when receiving an event from
+                federation whose prev events we don't have, e.g. backfilling.
         Returns:
             synapse.events.snapshot.EventContext:
         """
@@ -208,15 +215,22 @@ class StateHandler(object):
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+
+            # We don't store state for outliers, so we don't generate a state
+            # group for it.
+            context.state_group = None
+
             defer.returnValue(context)
 
         if old_state:
+            # We already have the state, so we don't need to calculate it.
+            # Let's just correctly fill out the context and create a
+            # new state group for it.
+
             context = EventContext()
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -229,11 +243,19 @@ class StateHandler(object):
             else:
                 context.current_state_ids = context.prev_state_ids
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=None,
+                delta_ids=None,
+                current_state_ids=context.current_state_ids,
+            )
+
             context.prev_state_events = []
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
-        entry = yield self.resolve_state_groups(
+        entry = yield self.resolve_state_groups_for_events(
             event.room_id, [e for e, _ in event.prev_events],
         )
 
@@ -242,7 +264,8 @@ class StateHandler(object):
         context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
+            # If this is a state event then we need to create a new state
+            # group for the state after this event.
 
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
@@ -253,38 +276,57 @@ class StateHandler(object):
             context.current_state_ids[key] = event.event_id
 
             if entry.state_group:
+                # If the state at the event has a state group assigned then
+                # we can use that as the prev group
                 context.prev_group = entry.state_group
                 context.delta_ids = {
                     key: event.event_id
                 }
             elif entry.prev_group:
+                # If the state at the event only has a prev group, then we can
+                # use that as a prev group too.
                 context.prev_group = entry.prev_group
                 context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=context.prev_group,
+                delta_ids=context.delta_ids,
+                current_state_ids=context.current_state_ids,
+            )
         else:
+            context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
             if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=entry.prev_group,
+                    delta_ids=entry.delta_ids,
+                    current_state_ids=context.current_state_ids,
+                )
                 entry.state_id = entry.state_group
 
             context.state_group = entry.state_group
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    @log_function
-    def resolve_state_groups(self, room_id, event_ids):
+    def resolve_state_groups_for_events(self, room_id, event_ids):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
+        Args:
+            room_id (str):
+            event_ids (list[str]):
+
         Returns:
-            a Deferred tuple of (`state_group`, `state`, `prev_state`).
-            `state_group` is the name of a state group if one and only one is
-            involved. `state` is a map from (type, state_key) to event, and
-            `prev_state` is a list of event ids.
+            Deferred[_StateCacheEntry]: resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -295,13 +337,7 @@ class StateHandler(object):
             room_id, event_ids
         )
 
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
-
-        group_names = frozenset(state_groups_ids.keys())
-        if len(group_names) == 1:
+        if len(state_groups_ids) == 1:
             name, state_list = state_groups_ids.items().pop()
 
             prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -313,6 +349,92 @@ class StateHandler(object):
                 delta_ids=delta_ids,
             ))
 
+        result = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups_ids, self._state_map_factory,
+        )
+        defer.returnValue(result)
+
+    def _state_map_factory(self, ev_ids):
+        return self.store.get_events(
+            ev_ids, get_prev_content=False, check_redacted=False,
+        )
+
+    def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
+
+        with Measure(self.clock, "state._resolve_events"):
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.items()
+        }
+
+        return new_state
+
+
+class StateResolutionHandler(object):
+    """Responsible for doing state conflict resolution.
+
+    Note that the storage layer depends on this handler, so all functions must
+    be storage-independent.
+    """
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+
+        # dict of set of event_ids -> _StateCacheEntry.
+        self._state_cache = None
+        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+    def start_caching(self):
+        logger.debug("start_caching")
+
+        self._state_cache = ExpiringCache(
+            cache_name="state_cache",
+            clock=self.clock,
+            max_len=SIZE_OF_CACHE,
+            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
+            reset_expiry_on_get=True,
+        )
+
+        self._state_cache.start()
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+        """Resolves conflicts between a set of state groups
+
+        Always generates a new state group (unless we hit the cache), so should
+        not be called for a single state group
+
+        Args:
+            room_id (str): room we are resolving for (used for logging)
+            state_groups_ids (dict[int, dict[(str, str), str]]):
+                map from state group id to the state in that state group
+                (where 'state' is a map from (type, state_key) to event id)
+
+        Returns:
+            Deferred[_StateCacheEntry]: resolved state
+        """
+        logger.debug(
+            "resolve_state_groups state_groups %s",
+            state_groups_ids.keys()
+        )
+
+        group_names = frozenset(state_groups_ids.keys())
+
         with (yield self.resolve_linearizer.queue(group_names)):
             if self._state_cache is not None:
                 cache = self._state_cache.get(group_names, None)
@@ -341,17 +463,19 @@ class StateHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events(
+                    new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
-                        state_map_factory=lambda ev_ids: self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        ),
+                        state_map_factory=state_map_factory,
                     )
             else:
                 new_state = {
                     key: e_ids.pop() for key, e_ids in state.items()
                 }
 
+            # if the new state matches any of the input state groups, we can
+            # use that state group again. Otherwise we will generate a state_id
+            # which will be used as a cache key for future resolutions, but
+            # not get persisted.
             state_group = None
             new_state_event_ids = frozenset(new_state.values())
             for sg, events in state_groups_ids.items():
@@ -388,30 +512,6 @@ class StateHandler(object):
 
             defer.returnValue(cache)
 
-    def resolve_events(self, state_sets, event):
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
-
-        with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events(state_set_ids, state_map)
-
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in new_state.items()
-        }
-
-        return new_state
-
 
 def _ordered_events(events):
     def key_func(e):
@@ -420,19 +520,17 @@ def _ordered_events(events):
     return sorted(events, key=key_func)
 
 
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
-        state_map_factory(dict|callable): If callable, then will be called
-            with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event. Otherwise, should be
-            a dict from event_id to event of all events in state_sets.
+        state_map(dict): a dict from event_id to event, for all events in
+            state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent] is a map from
-        (type, state_key) to event.
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -441,13 +539,6 @@ def resolve_events(state_sets, state_map_factory):
         state_sets,
     )
 
-    if callable(state_map_factory):
-        return _resolve_with_state_fac(
-            unconflicted_state, conflicted_state, state_map_factory
-        )
-
-    state_map = state_map_factory
-
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -461,6 +552,21 @@ def _seperate(state_sets):
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
+
+    Args:
+        state_sets(list[dict[(str, str), str]]):
+            List of dicts of (type, state_key) -> event_id, which are the
+            different state groups to resolve.
+
+    Returns:
+        (dict[(str, str), str], dict[(str, str), set[str]]):
+            A tuple of (unconflicted_state, conflicted_state), where:
+
+            unconflicted_state is a dict mapping (type, state_key)->event_id
+            for unconflicted state keys.
+
+            conflicted_state is a dict mapping (type, state_key) to a set of
+            event ids for conflicted state keys.
     """
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
@@ -491,8 +597,26 @@ def _seperate(state_sets):
 
 
 @defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
-                            state_map_factory):
+def resolve_events_with_factory(state_sets, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map_factory(func): will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event.
+
+    Returns:
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
+    """
+    if len(state_sets) == 1:
+        defer.returnValue(state_sets[0])
+
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
     needed_events = set(
         event_id
         for event_ids in conflicted_state.itervalues()
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d01d46338a..f8fbd02ceb 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e6eefdd6fe..68125006eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
+        """Starts a transaction on the database and runs a given function
 
-        start_time = time.time() * 1000
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
+
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        current_context = LoggingContext.current_context()
 
         after_callbacks = []
         final_callbacks = []
 
         def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runInteraction") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                current_context.copy_to(context)
-                return self._new_transaction(
-                    conn, desc, after_callbacks, final_callbacks, current_context,
-                    func, *args, **kwargs
-                )
+            return self._new_transaction(
+                conn, desc, after_callbacks, final_callbacks, current_context,
+                func, *args, **kwargs
+            )
 
         try:
-            with PreserveLoggingContext():
-                result = yield self._db_pool.runWithConnection(
-                    inner_func, *args, **kwargs
-                )
+            result = yield self.runWithConnection(inner_func, *args, **kwargs)
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
         current_context = LoggingContext.current_context()
 
         start_time = time.time() * 1000
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runWithConnection") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+                sched_duration_ms = time.time() * 1000 - start_time
+                sql_scheduling_timer.inc_by(sched_duration_ms)
+                current_context.add_database_scheduled(sched_duration_ms)
 
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
@@ -495,6 +508,7 @@ class SQLBaseStore(object):
             Deferred(bool): True if a new entry was created, False if an
                 existing one was updated.
         """
+        attempts = 0
         while True:
             try:
                 result = yield self.runInteraction(
@@ -504,6 +518,12 @@ class SQLBaseStore(object):
                 )
                 defer.returnValue(result)
             except self.database_engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
                 # presumably we raced with another transaction: let's retry.
                 logger.warn(
                     "IntegrityError when upserting into %s; retrying: %s",
@@ -547,7 +567,7 @@ class SQLBaseStore(object):
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
+        return a single row, returning multiple columns from it.
 
         Args:
             table : string giving the table name
@@ -600,20 +620,18 @@ class SQLBaseStore(object):
 
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
-        else:
-            where = ""
-
         sql = (
-            "SELECT %(retcol)s FROM %(table)s %(where)s"
+            "SELECT %(retcol)s FROM %(table)s"
         ) % {
             "retcol": retcol,
             "table": table,
-            "where": where,
         }
 
-        txn.execute(sql, keyvalues.values())
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+            txn.execute(sql, keyvalues.values())
+        else:
+            txn.execute(sql)
 
         return [r[0] for r in txn]
 
@@ -624,7 +642,7 @@ class SQLBaseStore(object):
 
         Args:
             table (str): table name
-            keyvalues (dict): column names and values to select the rows with
+            keyvalues (dict|None): column names and values to select the rows with
             retcol (str): column whos value we wish to retrieve.
 
         Returns:
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index c8a1eb016b..56a0bde549 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as room_account_data has a unique constraint
+            # on (user_id, room_id, account_data_type) so _simple_upsert will
+            # retry if there is a conflict.
+            yield self._simple_upsert(
+                desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
-            )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
-                user_id, next_id,
+                },
+                lock=False,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            self._update_max_stream_id(txn, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_room_account_data", add_account_data_txn, next_id
-            )
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+            self.get_account_data_for_user.invalidate((user_id,))
 
         result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
@@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as account_data has a unique constraint on
+            # (user_id, account_data_type) so _simple_upsert will retry if
+            # there is a conflict.
+            yield self._simple_upsert(
+                desc="add_user_account_data",
                 table="account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
+                },
+                lock=False,
             )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
+
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(
                 user_id, next_id,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            txn.call_after(
-                self.get_global_account_data_by_type_for_user.invalidate,
+            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_by_type_for_user.invalidate(
                 (account_data_type, user_id,)
             )
-            self._update_max_stream_id(txn, next_id)
-
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_user_account_data", add_account_data_txn, next_id
-            )
 
         result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
-    def _update_max_stream_id(self, txn, next_id):
+    def _update_max_stream_id(self, next_id):
         """Update the max stream_id
 
         Args:
-            txn: The database cursor
             next_id(int): The the revision to advance to.
         """
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
+        def _update(txn):
+            update_max_id_sql = (
+                "UPDATE account_data_max_stream_id"
+                " SET stream_id = ?"
+                " WHERE stream_id < ?"
+            )
+            txn.execute(update_max_id_sql, (next_id, next_id))
+        return self.runInteraction(
+            "update_account_data_max_stream_id",
+            _update,
         )
-        txn.execute(update_max_id_sql, (next_id, next_id))
 
     @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
     def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 6f235ac051..11a1b942f1 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
+        self._all_done = False
 
     @defer.inlineCallbacks
     def start_doing_background_updates(self):
@@ -106,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore):
                         "No more background updates to do."
                         " Unscheduling background update task."
                     )
+                    self._all_done = True
                     defer.returnValue(None)
 
     @defer.inlineCallbacks
+    def has_completed_background_updates(self):
+        """Check if all the background updates have completed
+
+        Returns:
+            Deferred[bool]: True if all background updates have completed
+        """
+        # if we've previously determined that there is nothing left to do, that
+        # is easy
+        if self._all_done:
+            defer.returnValue(True)
+
+        # obviously, if we have things in our queue, we're not done.
+        if self._background_update_queue:
+            defer.returnValue(False)
+
+        # otherwise, check if there are updates to be run. This is important,
+        # as we may be running on a worker which doesn't perform the bg updates
+        # itself, but still wants to wait for them to happen.
+        updates = yield self._simple_select_onecol(
+            "background_updates",
+            keyvalues=None,
+            retcol="1",
+            desc="check_background_updates",
+        )
+        if not updates:
+            self._all_done = True
+            defer.returnValue(True)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
     def do_next_background_update(self, desired_duration_ms):
         """Does some amount of work on the next queued background update
 
@@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             # Sqlite doesn't support concurrent creation of indexes.
             #
             # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy has 3.7
+            # until 3.8, and wheezy and CentOS 7 have 3.7
             #
             # We assume that sqlite doesn't give us invalid indices; however
             # we may still end up with the index existing but the
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
 
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        txn.execute("SELECT nextval('state_group_id_seq')")
+        return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
 from synapse.storage.prepare_database import prepare_database
 
 import struct
+import threading
 
 
 class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        # The current max state_group, or None if we haven't looked
+        # in the DB yet.
+        self._current_state_group_id = None
+        self._current_state_group_id_lock = threading.Lock()
+
     def check_database(self, txn):
         pass
 
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
     def lock_table(self, txn, table):
         return
 
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._current_state_group_id_lock:
+            if self._current_state_group_id is None:
+                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+                self._current_state_group_id = txn.fetchone()[0]
+
+            self._current_state_group_id += 1
+            return self._current_state_group_id
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index d08f7571d7..86a7c5920d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -27,7 +27,7 @@ from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
+from synapse.state import resolve_events_with_factory
 from synapse.util.caches.descriptors import cached
 from synapse.types import get_domain_from_id
 
@@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
 
-        deferred = ObservableDeferred(defer.Deferred())
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
@@ -146,18 +146,25 @@ class _EventPeristenceQueue(object):
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
+                    # handle_queue_loop runs in the sentinel logcontext, so
+                    # there is no need to preserve_fn when running the
+                    # callbacks on the deferred.
                     try:
                         ret = yield per_item_callback(item)
                         item.deferred.callback(ret)
-                    except Exception as e:
-                        item.deferred.errback(e)
+                    except Exception:
+                        item.deferred.errback()
             finally:
                 queue = self._event_persist_queues.pop(room_id, None)
                 if queue:
                     self._event_persist_queues[room_id] = queue
                 self._currently_persisting_rooms.discard(room_id)
 
-        preserve_fn(handle_queue_loop)()
+        # set handle_queue_loop off in the background. We don't want to
+        # attribute work done in it to the current request, so we drop the
+        # logcontext altogether.
+        with PreserveLoggingContext():
+            handle_queue_loop()
 
     def _get_drainining_queue(self, room_id):
         queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -335,8 +342,20 @@ class EventsStore(SQLBaseStore):
 
                 # NB: Assumes that we are only persisting events for one room
                 # at a time.
+
+                # map room_id->list[event_ids] giving the new forward
+                # extremities in each room
                 new_forward_extremeties = {}
+
+                # map room_id->(type,state_key)->event_id tracking the full
+                # state in each room after adding these events
                 current_state_for_room = {}
+
+                # map room_id->(to_delete, to_insert) where each entry is
+                # a map (type,key)->event_id giving the state delta in each
+                # room
+                state_delta_for_room = {}
+
                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
                         # Work out the new "current state" for each room.
@@ -379,11 +398,19 @@ class EventsStore(SQLBaseStore):
                                 if all_single_prev_not_state:
                                     continue
 
-                            state = yield self._calculate_state_delta(
-                                room_id, ev_ctx_rm, new_latest_event_ids
+                            logger.info(
+                                "Calculating state delta for room %s", room_id,
+                            )
+                            current_state = yield self._get_new_state_after_events(
+                                ev_ctx_rm, new_latest_event_ids,
                             )
-                            if state:
-                                current_state_for_room[room_id] = state
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state,
+                                )
+                                if delta is not None:
+                                    state_delta_for_room[room_id] = delta
 
                 yield self.runInteraction(
                     "persist_events",
@@ -391,7 +418,7 @@ class EventsStore(SQLBaseStore):
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
-                    current_state_for_room=current_state_for_room,
+                    state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
@@ -408,7 +435,7 @@ class EventsStore(SQLBaseStore):
 
                     event_counter.inc(event.type, origin_type, origin_entity)
 
-                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+                for room_id, new_state in current_state_for_room.iteritems():
                     self.get_current_state_ids.prefill(
                         (room_id, ), new_state
                     )
@@ -460,20 +487,22 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
-        """Calculate the new state deltas for a room.
+    def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+        """Calculate the current state dict after adding some new events to
+        a room
 
-        Assumes that we are only persisting events for one room at a time.
+        Args:
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
 
         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
-            May return None if there are no changes to be applied.
+            Deferred[dict[(str,str), str]|None]:
+                None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id.
         """
-        # Now we need to work out the different state sets for
-        # each state extremities
         state_sets = []
         state_groups = set()
         missing_event_ids = []
@@ -516,18 +545,23 @@ class EventsStore(SQLBaseStore):
                 state_sets.extend(group_to_state.itervalues())
 
         if not new_latest_event_ids:
-            current_state = {}
+            defer.returnValue({})
         elif was_updated:
             if len(state_sets) == 1:
                 # If there is only one state set, then we know what the current
                 # state is.
-                current_state = state_sets[0]
+                defer.returnValue(state_sets[0])
             else:
                 # We work out the current state by passing the state sets to the
                 # state resolution algorithm. It may ask for some events, including
                 # the events we have yet to persist, so we need a slightly more
                 # complicated event lookup function than simply looking the events
                 # up in the db.
+
+                logger.info(
+                    "Resolving state with %i state sets", len(state_sets),
+                )
+
                 events_map = {ev.event_id: ev for ev, _ in events_context}
 
                 @defer.inlineCallbacks
@@ -550,13 +584,26 @@ class EventsStore(SQLBaseStore):
                         to_return.update(evs)
                     defer.returnValue(to_return)
 
-                current_state = yield resolve_events(
+                current_state = yield resolve_events_with_factory(
                     state_sets,
                     state_map_factory=get_events,
                 )
+                defer.returnValue(current_state)
         else:
             return
 
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts,
+            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+        """
         existing_state = yield self.get_current_state_ids(room_id)
 
         existing_events = set(existing_state.itervalues())
@@ -576,7 +623,7 @@ class EventsStore(SQLBaseStore):
             if ev_id in events_to_insert
         }
 
-        defer.returnValue((to_delete, to_insert, current_state))
+        defer.returnValue((to_delete, to_insert))
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -636,7 +683,7 @@ class EventsStore(SQLBaseStore):
 
     @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, current_state_for_room={},
+                            delete_existing=False, state_delta_for_room={},
                             new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
@@ -652,7 +699,7 @@ class EventsStore(SQLBaseStore):
             delete_existing (bool): True to purge existing table rows for the
                 events from the database. This is useful when retrying due to
                 IntegrityError.
-            current_state_for_room (dict[str, (list[str], list[str])]):
+            state_delta_for_room (dict[str, (list[str], list[str])]):
                 The current-state delta for each room. For each room, a tuple
                 (to_delete, to_insert), being a list of event ids to be removed
                 from the current state, and a list of event ids to be added to
@@ -664,7 +711,7 @@ class EventsStore(SQLBaseStore):
         """
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
-        self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
 
         self._update_forward_extremities_txn(
             txn,
@@ -708,9 +755,8 @@ class EventsStore(SQLBaseStore):
             events_and_contexts=events_and_contexts,
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
 
         # _store_rejected_events_txn filters out any events which were
         # rejected, and returns the filtered list.
@@ -730,7 +776,7 @@ class EventsStore(SQLBaseStore):
 
     def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-                to_delete, to_insert, _ = current_state_tuple
+                to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
                     [(ev_id,) for ev_id in to_delete.itervalues()],
@@ -945,10 +991,9 @@ class EventsStore(SQLBaseStore):
                 # an outlier in the database. We now have some state at that
                 # so we need to update the state_groups table with that state.
 
-                # insert into the state_group, state_groups_state and
-                # event_to_state_groups tables.
+                # insert into event_to_state_groups.
                 try:
-                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
                 except Exception:
                     logger.exception("")
                     raise
@@ -2018,16 +2063,32 @@ class EventsStore(SQLBaseStore):
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def delete_old_state(self, room_id, topological_ordering):
-        return self.runInteraction(
-            "delete_old_state",
-            self._delete_old_state_txn, room_id, topological_ordering
-        )
+    def purge_history(
+        self, room_id, topological_ordering, delete_local_events,
+    ):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
 
-    def _delete_old_state_txn(self, txn, room_id, topological_ordering):
-        """Deletes old room state
+            topological_ordering (int):
+                minimum topo ordering to preserve
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
         """
 
+        return self.runInteraction(
+            "purge_history",
+            self._purge_history_txn, room_id, topological_ordering,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(
+        self, txn, room_id, topological_ordering, delete_local_events,
+    ):
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -2068,7 +2129,7 @@ class EventsStore(SQLBaseStore):
                 400, "topological_ordering is greater than forward extremeties"
             )
 
-        logger.debug("[purge] looking for events to delete")
+        logger.info("[purge] looking for events to delete")
 
         txn.execute(
             "SELECT event_id, state_key FROM events"
@@ -2080,16 +2141,16 @@ class EventsStore(SQLBaseStore):
 
         to_delete = [
             (event_id,) for event_id, state_key in event_rows
-            if state_key is None and not self.hs.is_mine_id(event_id)
+            if state_key is None and (
+                delete_local_events or not self.hs.is_mine_id(event_id)
+            )
         ]
         logger.info(
-            "[purge] found %i events before cutoff, of which %i are remote"
-            " non-state events to delete", len(event_rows), len(to_delete))
-
-        for event_id, state_key in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows), len(to_delete),
+        )
 
-        logger.debug("[purge] Finding new backward extremities")
+        logger.info("[purge] Finding new backward extremities")
 
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
@@ -2103,7 +2164,7 @@ class EventsStore(SQLBaseStore):
         )
         new_backwards_extrems = txn.fetchall()
 
-        logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
 
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2119,7 +2180,7 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
-        logger.debug("[purge] finding redundant state groups")
+        logger.info("[purge] finding redundant state groups")
 
         # Get all state groups that are only referenced by events that are
         # to be deleted.
@@ -2136,15 +2197,15 @@ class EventsStore(SQLBaseStore):
         )
 
         state_rows = txn.fetchall()
-        logger.debug("[purge] found %i redundant state groups", len(state_rows))
+        logger.info("[purge] found %i redundant state groups", len(state_rows))
 
         # make a set of the redundant state groups, so that we can look them up
         # efficiently
         state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        logger.debug("[purge] finding state groups which depend on redundant"
-                     " state groups")
+        logger.info("[purge] finding state groups which depend on redundant"
+                    " state groups")
         remaining_state_groups = []
         for i in xrange(0, len(state_rows), 100):
             chunk = [sg for sg, in state_rows[i:i + 100]]
@@ -2169,7 +2230,7 @@ class EventsStore(SQLBaseStore):
         # Now we turn the state groups that reference to-be-deleted state
         # groups to non delta versions.
         for sg in remaining_state_groups:
-            logger.debug("[purge] de-delta-ing remaining state group %s", sg)
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
                 txn, [sg], types=None
             )
@@ -2206,7 +2267,7 @@ class EventsStore(SQLBaseStore):
                 ],
             )
 
-        logger.debug("[purge] removing redundant state groups")
+        logger.info("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -2216,18 +2277,15 @@ class EventsStore(SQLBaseStore):
             state_rows
         )
 
-        # Delete all non-state
-        logger.debug("[purge] removing events from event_to_state_groups")
+        logger.info("[purge] removing events from event_to_state_groups")
         txn.executemany(
             "DELETE FROM event_to_state_groups WHERE event_id = ?",
             [(event_id,) for event_id, _ in event_rows]
         )
-
-        logger.debug("[purge] updating room_depth")
-        txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (topological_ordering, room_id,)
-        )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (
+                event_id,
+            ))
 
         # Delete all remote non-state events
         for table in (
@@ -2245,7 +2303,8 @@ class EventsStore(SQLBaseStore):
             "event_signatures",
             "rejections",
         ):
-            logger.debug("[purge] removing remote non-state events from %s", table)
+            logger.info("[purge] removing remote non-state events from %s",
+                        table)
 
             txn.executemany(
                 "DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -2253,16 +2312,30 @@ class EventsStore(SQLBaseStore):
             )
 
         # Mark all state and own events as outliers
-        logger.debug("[purge] marking remaining events as outliers")
+        logger.info("[purge] marking remaining events as outliers")
         txn.executemany(
             "UPDATE events SET outlier = ?"
             " WHERE event_id = ?",
             [
                 (True, event_id,) for event_id, state_key in event_rows
-                if state_key is not None or self.hs.is_mine_id(event_id)
+                if state_key is not None or (
+                    not delete_local_events and self.hs.is_mine_id(event_id)
+                )
             ]
         )
 
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        logger.info("[purge] updating room_depth")
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (topological_ordering, room_id,)
+        )
+
         logger.info("[purge] done")
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 52e5cdad70..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
 
-from ._base import SQLBaseStore
 
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def get_default_thumbnails(self, top_level_type, sub_type):
-        return []
+    def __init__(self, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            update_name='local_media_repository_url_idx',
+            index_name='local_media_repository_url_idx',
+            table='local_media_repository',
+            columns=['created_ts'],
+            where_clause='url_cache IS NOT NULL',
+        )
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
@@ -166,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+        """Updates the last access time of the given media
+
+        Args:
+            local_media (iterable[str]): Set of media_ids
+            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms: Current time in milliseconds
+        """
         def update_cache_txn(txn):
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
@@ -174,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
             )
 
             txn.executemany(sql, (
-                (time_ts, media_origin, media_id)
-                for media_origin, media_id in origin_id_tuples
+                (time_ms, media_origin, media_id)
+                for media_origin, media_id in remote_media
+            ))
+
+            sql = (
+                "UPDATE local_media_repository SET last_access_ts = ?"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, (
+                (time_ms, media_id)
+                for media_id in local_media
             ))
 
         return self.runInteraction("update_cached_last_access_time", update_cache_txn)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index d1691bbac2..c845a0cec5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 46
+SCHEMA_VERSION = 47
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index beea3102fc..ec02e73bc2 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -15,6 +15,9 @@
 
 from twisted.internet import defer
 
+from synapse.storage.roommember import ProfileInfo
+from synapse.api.errors import StoreError
+
 from ._base import SQLBaseStore
 
 
@@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
             desc="create_profile",
         )
 
+    @defer.inlineCallbacks
+    def get_profileinfo(self, user_localpart):
+        """Fetch the display name and avatar URL stored for a user.
+
+        Args:
+            user_localpart (str): value matched against the `user_id` column
+                of the `profiles` table
+
+        Returns:
+            Deferred[ProfileInfo]: the stored profile. If the lookup fails
+            with a 404 StoreError (no matching row), a ProfileInfo with both
+            fields set to None is returned instead.
+
+        Raises:
+            StoreError: if the lookup fails with any code other than 404
+        """
+        try:
+            profile = yield self._simple_select_one(
+                table="profiles",
+                keyvalues={"user_id": user_localpart},
+                retcols=("displayname", "avatar_url"),
+                desc="get_profileinfo",
+            )
+        except StoreError as e:
+            if e.code != 404:
+                raise
+
+            # no match. (returnValue() works by raising, so nothing after it
+            # in this branch would be reached.)
+            defer.returnValue(ProfileInfo(None, None))
+
+        defer.returnValue(
+            ProfileInfo(
+                avatar_url=profile['avatar_url'],
+                display_name=profile['displayname'],
+            )
+        )
+
     def get_profile_displayname(self, user_localpart):
         return self._simple_select_one_onecol(
             table="profiles",
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 8b9544c209..3aa810981f 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
-            defer.Deferred[list[str, str|None]]: a list of the deleted tokens
-                and device IDs
+            defer.Deferred[list[(str, int, str|None)]]: a list of
+                (token, token id, device id) for each of the deleted tokens
         """
         def f(txn):
             keyvalues = {
@@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
+                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
                 values
             )
-            tokens_and_devices = [(r[0], r[1]) for r in txn]
+            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for token, _ in tokens_and_devices:
+            for token, _, _ in tokens_and_devices:
                 self._invalidate_cache_and_stream(
                     txn, self.get_user_by_access_token, (token,)
                 )
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..fff6652e05 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,11 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.storage.search import SearchStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
 import collections
 import logging
 import ujson as json
@@ -40,7 +38,7 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
-class RoomStore(SQLBaseStore):
+class RoomStore(SearchStore):
 
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
@@ -263,8 +261,8 @@ class RoomStore(SQLBaseStore):
                 },
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"],
             )
 
     def _store_room_name_txn(self, txn, event):
@@ -279,14 +277,14 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"],
             )
 
     def _store_room_message_txn(self, txn, event):
         if hasattr(event, "content") and "body" in event.content:
-            self._store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"],
             )
 
     def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +306,6 @@ class RoomStore(SQLBaseStore):
                 event.content[key]
             ))
 
-    def _store_event_search_txn(self, txn, event, key, value):
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-            txn.execute(
-                sql,
-                (
-                    event.event_id, event.room_id, key, value,
-                    event.internal_metadata.stream_ordering,
-                    event.origin_server_ts,
-                )
-            )
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            txn.execute(sql, (event.event_id, event.room_id, key, value,))
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
     def add_event_report(self, room_id, event_id, user_id, reason, content,
                          received_ts):
         next_id = self._event_reports_id_gen.get_next()
@@ -533,73 +506,114 @@ class RoomStore(SQLBaseStore):
         )
         self.is_room_blocked.invalidate((room_id,))
 
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            The result of the `runInteraction` call, which yields a pair
+            `(local, remote)`: the local media and the remote media
+            referenced in the room, each as a list of
+            `mxc://<hostname>/<media_id>` URI strings.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+
+            # Convert the IDs to MXC URIs. Local media comes back as bare
+            # media IDs, so those are qualified with our own hostname.
+            local_media_mxcs = [
+                "mxc://%s/%s" % (self.hostname, media_id)
+                for media_id in local_mxcs
+            ]
+            remote_media_mxcs = [
+                "mxc://%s/%s" % (hostname, media_id)
+                for hostname, media_id in remote_mxcs
+            ]
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
         the associated media
         """
-        def _get_media_ids_in_room(txn):
-            mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
 
-            next_token = self.get_current_events_token() + 1
+            # Now update all the tables to set the quarantined_by flag
 
-            total_media_quarantined = 0
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
 
-            while next_token:
-                sql = """
-                    SELECT stream_ordering, content FROM events
-                    WHERE room_id = ?
-                        AND stream_ordering < ?
-                        AND contains_url = ? AND outlier = ?
-                    ORDER BY stream_ordering DESC
-                    LIMIT ?
+            txn.executemany(
                 """
-                txn.execute(sql, (room_id, next_token, True, False, 100))
-
-                next_token = None
-                local_media_mxcs = []
-                remote_media_mxcs = []
-                for stream_ordering, content_json in txn:
-                    next_token = stream_ordering
-                    content = json.loads(content_json)
-
-                    content_url = content.get("url")
-                    thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
-                    for url in (content_url, thumbnail_url):
-                        if not url:
-                            continue
-                        matches = mxc_re.match(url)
-                        if matches:
-                            hostname = matches.group(1)
-                            media_id = matches.group(2)
-                            if hostname == self.hostname:
-                                local_media_mxcs.append(media_id)
-                            else:
-                                remote_media_mxcs.append((hostname, media_id))
-
-                # Now update all the tables to set the quarantined_by flag
-
-                txn.executemany("""
-                    UPDATE local_media_repository
+                    UPDATE remote_media_cache
                     SET quarantined_by = ?
-                    WHERE media_id = ?
-                """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
-                txn.executemany(
-                    """
-                        UPDATE remote_media_cache
-                        SET quarantined_by = ?
-                        WHERE media_origin AND media_id = ?
-                    """,
-                    (
-                        (quarantined_by, origin, media_id)
-                        for origin, media_id in remote_media_mxcs
-                    )
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
                 )
+            )
 
-                total_media_quarantined += len(local_media_mxcs)
-                total_media_quarantined += len(remote_media_mxcs)
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
             return total_media_quarantined
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+        return self.runInteraction(
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
+        )
+
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves the IDs of all the local and remote media in a given room
+
+        Pages backwards through the room's non-outlier events which are
+        flagged as containing a URL, 100 at a time, and collects the media
+        referenced by the `url` and `info.thumbnail_url` fields of each
+        event's content.
+
+        Args:
+            txn (cursor)
+            room_id (str)
+
+        Returns:
+            (list[str], list[(str, str)]): the media IDs of the local media,
+            and (hostname, media ID) pairs for the remote media. Media which
+            is referenced more than once is listed more than once.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, content FROM events
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            # if this page turns out to be empty, next_token stays as None
+            # and the loop terminates.
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                content = json.loads(content_json)
+
+                content_url = content.get("url")
+
+                # don't assume that `info` is a dict: if it is present but
+                # is something else, calling .get() on it would raise and
+                # abort the whole transaction.
+                info = content.get("info")
+                if isinstance(info, dict):
+                    thumbnail_url = info.get("thumbnail_url")
+                else:
+                    thumbnail_url = None
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
index e2b775f038..b12f9b2ebf 100644
--- a/synapse/storage/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -13,7 +13,10 @@
  * limitations under the License.
  */
 
-CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
+-- removed and replaced with 46/local_media_repository_url_idx.sql.
+--
+-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
 
 -- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
 -- indices on expressions until 3.9.
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
new file mode 100644
index 0000000000..bbfc7f5d1a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will recreate the
+-- local_media_repository_url_idx index.
+--
+-- We do this as a bg update not because it is a particularly onerous
+-- operation, but because we'd like it to be a partial index if possible, and
+-- the background_index_update code will understand whether we are on
+-- postgres or sqlite and behave accordingly.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('local_media_repository_url_idx', '{}');
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
new file mode 100644
index 0000000000..cb0d5a2576
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- change the user_directory table to also cover global local user profiles
+-- rather than just profiles within specific rooms.
+
+CREATE TABLE user_directory2 (
+    user_id TEXT NOT NULL,
+    room_id TEXT,
+    display_name TEXT,
+    avatar_url TEXT
+);
+
+INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
+    SELECT user_id, room_id, display_name, avatar_url from user_directory;
+
+DROP TABLE user_directory;
+ALTER TABLE user_directory2 RENAME TO user_directory;
+
+-- create indexes after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    """Create the `state_group_id_seq` sequence when running on postgres.
+
+    The sequence starts one above the largest existing `state_groups.id`, or
+    at 1 if there are no state groups yet. Nothing is done for any other
+    database engine.
+
+    Args:
+        cur (cursor): cursor to run the statements on
+        database_engine: the engine in use; only PostgresEngine is acted on
+    """
+    if not isinstance(database_engine, PostgresEngine):
+        return
+
+    # if we already have some state groups, we want to start making new
+    # ones with a higher id.
+    cur.execute("SELECT max(id) FROM state_groups")
+    highest_id = cur.fetchone()[0]
+
+    start_val = 1 if highest_id is None else highest_id + 1
+
+    cur.execute(
+        "CREATE SEQUENCE state_group_id_seq START WITH %s",
+        (start_val, ),
+    )
+
+
+def run_upgrade(*args, **kwargs):
+    """Does nothing; accepts and ignores any arguments."""
+    pass
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 479b04c636..f1ac9ba0fd 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,19 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import namedtuple
+import logging
+import re
+import ujson as json
+
 from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import logging
-import re
-import ujson as json
-
 
 logger = logging.getLogger(__name__)
 
+# A single row to be added to the event_search table. `key` names the part of
+# the event which `value` was taken from (e.g. "content.body").
+# `stream_ordering` and `origin_server_ts` are only written to the table when
+# running on postgres.
+SearchEntry = namedtuple('SearchEntry', [
+    'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+    'origin_server_ts',
+])
+
 
 class SearchStore(BackgroundUpdateStore):
 
@@ -49,16 +55,17 @@ class SearchStore(BackgroundUpdateStore):
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
+        # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
         def reindex_search_txn(txn):
             sql = (
-                "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+                "SELECT stream_ordering, event_id, room_id, type, content, "
+                " origin_server_ts FROM events"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
                 " AND (%s)"
                 " ORDER BY stream_ordering DESC"
@@ -67,6 +74,10 @@ class SearchStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
+            # we could stream straight from the results into
+            # store_search_entries_txn with a generator function, but that
+            # would mean having two cursors open on the database at once.
+            # Instead we just build a list of results.
             rows = self.cursor_to_dict(txn)
             if not rows:
                 return 0
@@ -79,6 +90,8 @@ class SearchStore(BackgroundUpdateStore):
                     event_id = row["event_id"]
                     room_id = row["room_id"]
                     etype = row["type"]
+                    stream_ordering = row["stream_ordering"]
+                    origin_server_ts = row["origin_server_ts"]
                     try:
                         content = json.loads(row["content"])
                     except Exception:
@@ -93,6 +106,8 @@ class SearchStore(BackgroundUpdateStore):
                     elif etype == "m.room.name":
                         key = "content.name"
                         value = content["name"]
+                    else:
+                        raise Exception("unexpected event type %s" % etype)
                 except (KeyError, AttributeError):
                     # If the event is missing a necessary field then
                     # skip over it.
@@ -103,25 +118,16 @@ class SearchStore(BackgroundUpdateStore):
                     # then skip over it
                     continue
 
-                event_search_rows.append((event_id, room_id, key, value))
+                event_search_rows.append(SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event_id,
+                    room_id=room_id,
+                    stream_ordering=stream_ordering,
+                    origin_server_ts=origin_server_ts,
+                ))
 
-            if isinstance(self.database_engine, PostgresEngine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, vector)"
-                    " VALUES (?,?,?,to_tsvector('english', ?))"
-                )
-            elif isinstance(self.database_engine, Sqlite3Engine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, value)"
-                    " VALUES (?,?,?,?)"
-                )
-            else:
-                # This should be unreachable.
-                raise Exception("Unrecognized database engine")
-
-            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
-                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            self.store_search_entries_txn(txn, event_search_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -242,6 +248,62 @@ class SearchStore(BackgroundUpdateStore):
 
         defer.returnValue(num_rows)
 
+    def store_event_search_txn(self, txn, event, key, value):
+        """Index one field of an event in the search table
+
+        Wraps the event's details up as a single SearchEntry and hands it to
+        store_search_entries_txn.
+
+        Args:
+            txn (cursor):
+            event (EventBase): the event which the text was taken from
+            key (str): name of the field being indexed
+            value (str): the text to be indexed
+        """
+        entry = SearchEntry(
+            key=key,
+            value=value,
+            event_id=event.event_id,
+            room_id=event.room_id,
+            stream_ordering=event.internal_metadata.stream_ordering,
+            origin_server_ts=event.origin_server_ts,
+        )
+
+        self.store_search_entries_txn(txn, (entry,))
+
+    def store_search_entries_txn(self, txn, entries):
+        """Insert a batch of entries into the search table
+
+        On postgres the text is converted with to_tsvector and the stream
+        ordering and origin_server_ts are stored alongside it; on sqlite only
+        the event ID, room ID, key and raw value are stored.
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+
+        Raises:
+            Exception: if the database engine is neither postgres nor sqlite
+        """
+        engine = self.database_engine
+
+        if isinstance(engine, PostgresEngine):
+            insert_sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+            rows = (
+                (
+                    entry.event_id, entry.room_id, entry.key, entry.value,
+                    entry.stream_ordering, entry.origin_server_ts,
+                )
+                for entry in entries
+            )
+        elif isinstance(engine, Sqlite3Engine):
+            insert_sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            rows = (
+                (entry.event_id, entry.room_id, entry.key, entry.value)
+                for entry in entries
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        # the rows are generated lazily as executemany consumes them
+        txn.executemany(insert_sql, rows)
+
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 360e3e4355..adb48df73e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
-class StateGroupReadStore(SQLBaseStore):
-    """The read-only parts of StateGroupStore
-
-    None of these functions write to the state tables, so are suitable for
-    including in the SlavedStores.
+class StateGroupWorkerStore(SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
     """
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
     def __init__(self, db_conn, hs):
-        super(StateGroupReadStore, self).__init__(db_conn, hs)
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
         self._state_group_cache = DictionaryCache(
             "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        """Store a new set of state, returning a newly assigned state group.
 
-class StateStore(StateGroupReadStore, BackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
-    """
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
-        )
-
-    def _have_persisted_state_group_txn(self, txn, state_group):
-        txn.execute(
-            "SELECT count(*) FROM state_groups WHERE id = ?",
-            (state_group,)
-        )
-        row = txn.fetchone()
-        return row and row[0]
-
-    def _store_mult_state_groups_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
 
-            if context.current_state_ids is None:
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
                 # AFAIK, this can never happen
-                logger.error(
-                    "Non-outlier event %s had current_state_ids==None",
-                    event.event_id)
-                continue
+                raise Exception("current_state_ids cannot be None")
 
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-            if self._have_persisted_state_group_txn(txn, context.state_group):
-                continue
+            state_group = self.database_engine.get_next_state_group_id(txn)
 
             self._simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={
-                    "id": context.state_group,
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
+                    "id": state_group,
+                    "room_id": room_id,
+                    "event_id": event_id,
                 },
             )
 
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't tooo long, as otherwise read performance degrades.
-            if context.prev_group:
+            if prev_group:
                 is_in_db = self._simple_select_one_onecol_txn(
                     txn,
                     table="state_groups",
-                    keyvalues={"id": context.prev_group},
+                    keyvalues={"id": prev_group},
                     retcol="id",
                     allow_none=True,
                 )
                 if not is_in_db:
                     raise Exception(
                         "Trying to persist state with unpersisted prev_group: %r"
-                        % (context.prev_group,)
+                        % (prev_group,)
                     )
 
                 potential_hops = self._count_state_group_hops_txn(
-                    txn, context.prev_group
+                    txn, prev_group
                 )
-            if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
                 self._simple_insert_txn(
                     txn,
                     table="state_group_edges",
                     values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
+                        "state_group": state_group,
+                        "prev_state_group": prev_group,
                     },
                 )
 
@@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.delta_ids.iteritems()
+                        for key, state_id in delta_ids.iteritems()
                     ],
                 )
             else:
@@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
                     table="state_groups_state",
                     values=[
                         {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
+                            "state_group": state_group,
+                            "room_id": room_id,
                             "type": key[0],
                             "state_key": key[1],
                             "event_id": state_id,
                         }
-                        for key, state_id in context.current_state_ids.iteritems()
+                        for key, state_id in current_state_ids.iteritems()
                     ],
                 )
 
@@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
-                key=context.state_group,
-                value=dict(context.current_state_ids),
+                key=state_group,
+                value=dict(current_state_ids),
                 full=True,
             )
 
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent), the event inherits the state group from its parent.
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event within the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME,
+            self._background_index_state,
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
 
             return count
 
-    def get_next_state_group(self):
-        return self._state_groups_id_gen.get_next()
-
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
         """This background update will slowly deduplicate state by reencoding
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 5dc5b9582a..dfdcbb3181 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
             )
 
             if isinstance(self.database_engine, PostgresEngine):
-                # We weight the loclpart most highly, then display name and finally
+                # We weight the localpart most highly, then display name and finally
                 # server name
                 if new_entry:
                     sql = """
@@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_rooms", None, sql)
         defer.returnValue([room_id for room_id, in rows])
 
+    @defer.inlineCallbacks
+    def get_all_local_users(self):
+        """Get all local users
+        """
+        sql = """
+            SELECT name FROM users
+        """
+        rows = yield self._execute("get_all_local_users", None, sql)
+        defer.returnValue([name for name, in rows])
+
     def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
         """Insert entries into the users_who_share_rooms table. The first
         user should be a local user.
@@ -629,6 +639,25 @@ class UserDirectoryStore(SQLBaseStore):
                     ]
                 }
         """
+
+        if self.hs.config.user_directory_search_all_users:
+            # make s.user_id null to keep the ordering algorithm happy
+            join_clause = """
+                CROSS JOIN (SELECT NULL as user_id) AS s
+            """
+            join_args = ()
+            where_clause = "1=1"
+        else:
+            join_clause = """
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+            """
+            join_args = (user_id,)
+            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
 
@@ -641,13 +670,9 @@ class UserDirectoryStore(SQLBaseStore):
                 SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_public_rooms AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND vector @@ to_tsquery('english', ?)
                 ORDER BY
                     (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -671,8 +696,8 @@ class UserDirectoryStore(SQLBaseStore):
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+            """ % (join_clause, where_clause)
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_sqlite(search_term)
 
@@ -680,21 +705,17 @@ class UserDirectoryStore(SQLBaseStore):
                 SELECT d.user_id, display_name, avatar_url
                 FROM user_directory_search
                 INNER JOIN user_directory AS d USING (user_id)
-                LEFT JOIN users_in_public_rooms AS p USING (user_id)
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                %s
                 WHERE
-                    (s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
+                    %s
                     AND value MATCH ?
                 ORDER BY
                     rank(matchinfo(user_directory_search)) DESC,
                     display_name IS NULL,
                     avatar_url IS NULL
                 LIMIT ?
-            """
-            args = (user_id, search_query, limit + 1)
+            """ % (join_clause, where_clause)
+            args = join_args + (search_query, limit + 1)
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
@@ -723,7 +744,7 @@ def _parse_query_sqlite(search_term):
 
     # Pull out the individual words, discarding any non-word characters.
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
-    return " & ".join("(%s* | %s)" % (result, result,) for result in results)
+    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
 
 
 def _parse_query_postgres(search_term):
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
             size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..f088dd430e 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+    """A consumer that writes to a file like object. Supports both push
+    and pull producers
+
+    Args:
+        file_obj (file): The file like object to write to. Closed when
+            finished.
+    """
+
+    # For PushProducers pause if we have this many unwritten slices
+    _PAUSE_ON_QUEUE_SIZE = 5
+    # And resume once the size of the queue is less than this
+    _RESUME_ON_QUEUE_SIZE = 2
+
+    def __init__(self, file_obj):
+        self._file_obj = file_obj
+
+        # Producer we're registered with
+        self._producer = None
+
+        # True if PushProducer, false if PullProducer
+        self.streaming = False
+
+        # For PushProducers, indicates whether we've paused the producer and
+        # need to call resumeProducing before we get more data.
+        self._paused_producer = False
+
+        # Queue of slices of bytes to be written. When producer calls
+        # unregister a final None is sent.
+        self._bytes_queue = Queue.Queue()
+
+        # Deferred that is resolved when finished writing
+        self._finished_deferred = None
+
+        # If the _writer thread throws an exception it gets stored here.
+        self._write_exception = None
+
+    def registerProducer(self, producer, streaming):
+        """Part of IConsumer interface
+
+        Args:
+            producer (IProducer)
+            streaming (bool): True if push based producer, False if pull
+                based.
+        """
+        if self._producer:
+            raise Exception("registerProducer called twice")
+
+        self._producer = producer
+        self.streaming = streaming
+        self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+        if not streaming:
+            self._producer.resumeProducing()
+
+    def unregisterProducer(self):
+        """Part of IConsumer interface
+        """
+        self._producer = None
+        if not self._finished_deferred.called:
+            self._bytes_queue.put_nowait(None)
+
+    def write(self, bytes):
+        """Part of IConsumer interface
+        """
+        if self._write_exception:
+            raise self._write_exception
+
+        if self._finished_deferred.called:
+            raise Exception("consumer has closed")
+
+        self._bytes_queue.put_nowait(bytes)
+
+        # If this is a PushProducer and the queue is getting behind
+        # then we pause the producer.
+        if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+            self._paused_producer = True
+            self._producer.pauseProducing()
+
+    def _writer(self):
+        """This is run in a background thread to write to the file.
+        """
+        try:
+            while self._producer or not self._bytes_queue.empty():
+                # If we've paused the producer check if we should resume the
+                # producer.
+                if self._producer and self._paused_producer:
+                    if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+                        reactor.callFromThread(self._resume_paused_producer)
+
+                bytes = self._bytes_queue.get()
+
+                # If we get a None (or empty list) then that's a signal used
+                # to indicate we should check if we should stop.
+                if bytes:
+                    self._file_obj.write(bytes)
+
+                # If it's a pull producer then we need to explicitly ask for
+                # more stuff.
+                if not self.streaming and self._producer:
+                    reactor.callFromThread(self._producer.resumeProducing)
+        except Exception as e:
+            self._write_exception = e
+            raise
+        finally:
+            self._file_obj.close()
+
+    def wait(self):
+        """Returns a deferred that resolves when finished writing to file
+        """
+        return make_deferred_yieldable(self._finished_deferred)
+
+    def _resume_paused_producer(self):
+        """Gets called if we should resume producing after being paused
+        """
+        if self._paused_producer and self._producer:
+            self._paused_producer = False
+            self._producer.resumeProducing()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 48c9f6802d..94fa7cac98 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -52,13 +52,17 @@ except Exception:
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
+
     Args:
         name (str): Name for the context for debugging.
     """
 
     __slots__ = [
-        "previous_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag", "alive",
+        "previous_context", "name", "ru_stime", "ru_utime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "usage_start", "usage_end",
+        "main_thread", "alive",
+        "request", "tag",
     ]
 
     thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def add_database_scheduled(self, sched_ms):
+            pass
+
         def __nonzero__(self):
             return False
 
@@ -94,9 +101,17 @@ class LoggingContext(object):
         self.ru_stime = 0.
         self.ru_utime = 0.
         self.db_txn_count = 0
-        self.db_txn_duration = 0.
+
+        # ms spent waiting for db txns, excluding scheduling time
+        self.db_txn_duration_ms = 0
+
+        # ms spent waiting for db txns to be scheduled
+        self.db_sched_duration_ms = 0
+
         self.usage_start = None
+        self.usage_end = None
         self.main_thread = threading.current_thread()
+        self.request = None
         self.tag = ""
         self.alive = True
 
@@ -105,7 +120,11 @@ class LoggingContext(object):
 
     @classmethod
     def current_context(cls):
-        """Get the current logging context from thread local storage"""
+        """Get the current logging context from thread local storage
+
+        Returns:
+            LoggingContext: the current logging context
+        """
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
     @classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
         self.alive = False
 
     def copy_to(self, record):
-        """Copy fields from this context to the record"""
-        for key, value in self.__dict__.items():
-            setattr(record, key, value)
+        """Copy logging fields from this context to a log record or
+        another LoggingContext
+        """
 
-        record.ru_utime, record.ru_stime = self.get_resource_usage()
+        # 'request' is the only field we currently use in the logger, so that's
+        # all we need to copy
+        record.request = self.request
 
     def start(self):
         if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
 
     def add_database_transaction(self, duration_ms):
         self.db_txn_count += 1
-        self.db_txn_duration += duration_ms / 1000.
+        self.db_txn_duration_ms += duration_ms
+
+    def add_database_scheduled(self, sched_ms):
+        """Record a use of the database pool
+
+        Args:
+            sched_ms (int): number of milliseconds it took us to get a
+                connection
+        """
+        self.db_sched_duration_ms += sched_ms
 
 
 class LoggingContextFilter(logging.Filter):
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-block_timer = metrics.register_distribution(
-    "block_timer",
-    labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+    "block_count",
+    labels=["block_name"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_block_timer:count",
+            "_block_ru_utime:count",
+            "_block_ru_stime:count",
+            "_block_db_txn_count:count",
+            "_block_db_txn_duration:count",
+        )
+    )
+)
+
+block_timer = metrics.register_counter(
+    "block_time_seconds",
+    labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_timer:total",
+    ),
 )
 
-block_ru_utime = metrics.register_distribution(
-    "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+    "block_ru_utime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_utime:total",
+    ),
 )
 
-block_ru_stime = metrics.register_distribution(
-    "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+    "block_ru_stime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_stime:total",
+    ),
 )
 
-block_db_txn_count = metrics.register_distribution(
-    "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+    "block_db_txn_count", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
 )
 
-block_db_txn_duration = metrics.register_distribution(
-    "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+    "block_db_txn_duration_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_duration:total",
+    ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+    "block_db_sched_duration_seconds", labels=["block_name"],
 )
 
 
@@ -64,7 +101,9 @@ def measure_func(name):
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+        "ru_stime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "created_context",
     ]
 
     def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
 
         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
         self.db_txn_count = self.start_context.db_txn_count
-        self.db_txn_duration = self.start_context.db_txn_duration
+        self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+        self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
             return
 
         duration = self.clock.time_msec() - self.start
+
+        block_counter.inc(self.name)
         block_timer.inc_by(duration, self.name)
 
         context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
             context.db_txn_count - self.db_txn_count, self.name
         )
         block_db_txn_duration.inc_by(
-            context.db_txn_duration - self.db_txn_duration, self.name
+            (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+            self.name
+        )
+        block_db_sched_duration.inc_by(
+            (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+            self.name
         )
 
         if self.created_context:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1adedbb361..47b0bb5eb3 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
 
 class NotRetryingDestination(Exception):
     def __init__(self, retry_last_ts, retry_interval, destination):
+        """Raised by the limiter (and federation client) to indicate that we
+        are deliberately not attempting to contact a given server.
+
+        Args:
+            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+                to contact the server.  0 indicates that the last attempt was
+                successful or that we've never actually attempted to connect.
+            retry_interval (int): the time in milliseconds to wait until the next
+                attempt.
+            destination (str): the domain in question
+        """
+
         msg = "Not retrying server %s." % (destination,)
         super(NotRetryingDestination, self).__init__(msg)
 
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+    """Checks whether a given format of 3PID is allowed to be used on this HS
+
+    Args:
+        hs (synapse.server.HomeServer): server
+        medium (str): 3pid medium - e.g. email, msisdn
+        address (str): address within that medium (e.g. "wotan@matrix.org")
+            msisdns need to first have been canonicalised
+    Returns:
+        bool: whether the 3PID medium/address is allowed to be added to this HS
+    """
+
+    if hs.config.allowed_local_3pids:
+        for constraint in hs.config.allowed_local_3pids:
+            logger.debug(
+                "Checking 3PID %s (%s) against %s (%s)",
+                address, medium, constraint['pattern'], constraint['medium'],
+            )
+            if (
+                medium == constraint['medium'] and
+                re.match(constraint['pattern'], address)
+            ):
+                return True
+    else:
+        return True
+
+    return False