summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py116
-rw-r--r--synapse/api/constants.py5
-rw-r--r--synapse/api/errors.py132
-rw-r--r--synapse/api/filtering.py300
-rw-r--r--synapse/app/_base.py178
-rw-r--r--synapse/app/appservice.py154
-rw-r--r--synapse/app/client_reader.py136
-rw-r--r--synapse/app/event_creator.py189
-rw-r--r--synapse/app/federation_reader.py129
-rw-r--r--synapse/app/federation_sender.py304
-rw-r--r--synapse/app/frontend_proxy.py227
-rwxr-xr-xsynapse/app/homeserver.py432
-rw-r--r--synapse/app/media_repository.py145
-rw-r--r--synapse/app/pusher.py237
-rw-r--r--synapse/app/synchrotron.py506
-rwxr-xr-xsynapse/app/synctl.py83
-rw-r--r--synapse/app/user_dir.py231
-rw-r--r--synapse/appservice/__init__.py87
-rw-r--r--synapse/appservice/api.py7
-rw-r--r--synapse/appservice/scheduler.py39
-rw-r--r--synapse/config/_base.py58
-rw-r--r--synapse/config/appservice.py13
-rw-r--r--synapse/config/cas.py2
-rw-r--r--synapse/config/emailconfig.py16
-rw-r--r--synapse/config/groups.py32
-rw-r--r--synapse/config/homeserver.py7
-rw-r--r--synapse/config/key.py8
-rw-r--r--synapse/config/logger.py129
-rw-r--r--synapse/config/password_auth_providers.py43
-rw-r--r--synapse/config/push.py61
-rw-r--r--synapse/config/registration.py31
-rw-r--r--synapse/config/repository.py81
-rw-r--r--synapse/config/server.py99
-rw-r--r--synapse/config/spam_checker.py35
-rw-r--r--synapse/config/tls.py18
-rw-r--r--synapse/config/user_directory.py44
-rw-r--r--synapse/config/voip.py8
-rw-r--r--synapse/config/workers.py20
-rw-r--r--synapse/crypto/context_factory.py11
-rw-r--r--synapse/crypto/event_signing.py15
-rw-r--r--synapse/crypto/keyclient.py17
-rw-r--r--synapse/crypto/keyring.py343
-rw-r--r--synapse/event_auth.py12
-rw-r--r--synapse/events/__init__.py18
-rw-r--r--synapse/events/builder.py2
-rw-r--r--synapse/events/snapshot.py132
-rw-r--r--synapse/events/spamcheck.py113
-rw-r--r--synapse/events/utils.py23
-rw-r--r--synapse/federation/__init__.py8
-rw-r--r--synapse/federation/federation_base.py185
-rw-r--r--synapse/federation/federation_client.py167
-rw-r--r--synapse/federation/federation_server.py544
-rw-r--r--synapse/federation/replication.py73
-rw-r--r--synapse/federation/send_queue.py384
-rw-r--r--synapse/federation/transaction_queue.py473
-rw-r--r--synapse/federation/transport/client.py556
-rw-r--r--synapse/federation/transport/server.py661
-rw-r--r--synapse/groups/__init__.py0
-rw-r--r--synapse/groups/attestations.py199
-rw-r--r--synapse/groups/groups_server.py950
-rw-r--r--synapse/handlers/__init__.py4
-rw-r--r--synapse/handlers/_base.py36
-rw-r--r--synapse/handlers/appservice.py71
-rw-r--r--synapse/handlers/auth.py490
-rw-r--r--synapse/handlers/deactivate_account.py52
-rw-r--r--synapse/handlers/device.py351
-rw-r--r--synapse/handlers/devicemessage.py16
-rw-r--r--synapse/handlers/directory.py27
-rw-r--r--synapse/handlers/e2e_keys.py203
-rw-r--r--synapse/handlers/federation.py762
-rw-r--r--synapse/handlers/groups_local.py471
-rw-r--r--synapse/handlers/identity.py56
-rw-r--r--synapse/handlers/initial_sync.py27
-rw-r--r--synapse/handlers/message.py696
-rw-r--r--synapse/handlers/presence.py342
-rw-r--r--synapse/handlers/profile.py167
-rw-r--r--synapse/handlers/read_marker.py64
-rw-r--r--synapse/handlers/receipts.py72
-rw-r--r--synapse/handlers/register.py123
-rw-r--r--synapse/handlers/room.py76
-rw-r--r--synapse/handlers/room_list.py163
-rw-r--r--synapse/handlers/room_member.py405
-rw-r--r--synapse/handlers/room_member_worker.py102
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py56
-rw-r--r--synapse/handlers/sync.py407
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py681
-rw-r--r--synapse/http/__init__.py22
-rw-r--r--synapse/http/additional_resource.py55
-rw-r--r--synapse/http/client.py197
-rw-r--r--synapse/http/endpoint.py50
-rw-r--r--synapse/http/matrixfederationclient.py452
-rw-r--r--synapse/http/server.py238
-rw-r--r--synapse/http/servlet.py37
-rw-r--r--synapse/http/site.py18
-rw-r--r--synapse/metrics/__init__.py66
-rw-r--r--synapse/metrics/metric.py168
-rw-r--r--synapse/module_api/__init__.py123
-rw-r--r--synapse/notifier.py145
-rw-r--r--synapse/push/action_generator.py16
-rw-r--r--synapse/push/baserules.py23
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py457
-rw-r--r--synapse/push/emailpusher.py31
-rw-r--r--synapse/push/httppusher.py75
-rw-r--r--synapse/push/mailer.py95
-rw-r--r--synapse/push/push_rule_evaluator.py154
-rw-r--r--synapse/push/push_tools.py15
-rw-r--r--synapse/push/pusher.py60
-rw-r--r--synapse/push/pusherpool.py63
-rw-r--r--synapse/python_dependencies.py34
-rw-r--r--synapse/replication/expire_cache.py60
-rw-r--r--synapse/replication/http/__init__.py30
-rw-r--r--synapse/replication/http/membership.py334
-rw-r--r--synapse/replication/http/send_event.py160
-rw-r--r--synapse/replication/presence_resource.py59
-rw-r--r--synapse/replication/pusher_resource.py54
-rw-r--r--synapse/replication/resource.py576
-rw-r--r--synapse/replication/slave/storage/_base.py37
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py5
-rw-r--r--synapse/replication/slave/storage/account_data.py94
-rw-r--r--synapse/replication/slave/storage/appservice.py30
-rw-r--r--synapse/replication/slave/storage/client_ips.py47
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py31
-rw-r--r--synapse/replication/slave/storage/devices.py26
-rw-r--r--synapse/replication/slave/storage/directory.py8
-rw-r--r--synapse/replication/slave/storage/events.py263
-rw-r--r--synapse/replication/slave/storage/groups.py54
-rw-r--r--synapse/replication/slave/storage/presence.py28
-rw-r--r--synapse/replication/slave/storage/profile.py21
-rw-r--r--synapse/replication/slave/storage/push_rule.py47
-rw-r--r--synapse/replication/slave/storage/pushers.py28
-rw-r--r--synapse/replication/slave/storage/receipts.py60
-rw-r--r--synapse/replication/slave/storage/registration.py18
-rw-r--r--synapse/replication/slave/storage/room.py32
-rw-r--r--synapse/replication/tcp/__init__.py30
-rw-r--r--synapse/replication/tcp/client.py203
-rw-r--r--synapse/replication/tcp/commands.py386
-rw-r--r--synapse/replication/tcp/protocol.py655
-rw-r--r--synapse/replication/tcp/resource.py307
-rw-r--r--synapse/replication/tcp/streams.py506
-rw-r--r--synapse/rest/__init__.py6
-rw-r--r--synapse/rest/client/v1/admin.py304
-rw-r--r--synapse/rest/client/v1/base.py6
-rw-r--r--synapse/rest/client/v1/directory.py8
-rw-r--r--synapse/rest/client/v1/login.py170
-rw-r--r--synapse/rest/client/v1/logout.py36
-rw-r--r--synapse/rest/client/v1/presence.py5
-rw-r--r--synapse/rest/client/v1/profile.py22
-rw-r--r--synapse/rest/client/v1/pusher.py16
-rw-r--r--synapse/rest/client/v1/register.py56
-rw-r--r--synapse/rest/client/v1/room.py123
-rw-r--r--synapse/rest/client/v1/voip.py5
-rw-r--r--synapse/rest/client/v2_alpha/_base.py51
-rw-r--r--synapse/rest/client/v2_alpha/account.py274
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py9
-rw-r--r--synapse/rest/client/v2_alpha/devices.py71
-rw-r--r--synapse/rest/client/v2_alpha/filter.py8
-rw-r--r--synapse/rest/client/v2_alpha/groups.py786
-rw-r--r--synapse/rest/client/v2_alpha/keys.py24
-rw-r--r--synapse/rest/client/v2_alpha/notifications.py2
-rw-r--r--synapse/rest/client/v2_alpha/read_marker.py66
-rw-r--r--synapse/rest/client/v2_alpha/register.py315
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/client/v2_alpha/sync.py81
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py19
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py79
-rw-r--r--synapse/rest/client/versions.py1
-rw-r--r--synapse/rest/key/v2/local_key_resource.py5
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py10
-rw-r--r--synapse/rest/media/v1/_base.py154
-rw-r--r--synapse/rest/media/v1/download_resource.py60
-rw-r--r--synapse/rest/media/v1/filepath.py170
-rw-r--r--synapse/rest/media/v1/media_repository.py733
-rw-r--r--synapse/rest/media/v1/media_storage.py263
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py281
-rw-r--r--synapse/rest/media/v1/storage_provider.py145
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py192
-rw-r--r--synapse/rest/media/v1/thumbnailer.py25
-rw-r--r--synapse/rest/media/v1/upload_resource.py8
-rw-r--r--synapse/server.py164
-rw-r--r--synapse/server.pyi29
-rw-r--r--synapse/state.py480
-rw-r--r--synapse/storage/__init__.py204
-rw-r--r--synapse/storage/_base.py327
-rw-r--r--synapse/storage/account_data.py191
-rw-r--r--synapse/storage/appservice.py135
-rw-r--r--synapse/storage/background_updates.py173
-rw-r--r--synapse/storage/client_ips.py153
-rw-r--r--synapse/storage/deviceinbox.py79
-rw-r--r--synapse/storage/devices.py249
-rw-r--r--synapse/storage/directory.py60
-rw-r--r--synapse/storage/end_to_end_keys.py108
-rw-r--r--synapse/storage/engines/__init__.py5
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite3.py19
-rw-r--r--synapse/storage/event_federation.py407
-rw-r--r--synapse/storage/event_push_actions.py695
-rw-r--r--synapse/storage/events.py1593
-rw-r--r--synapse/storage/events_worker.py416
-rw-r--r--synapse/storage/filtering.py12
-rw-r--r--synapse/storage/group_server.py1253
-rw-r--r--synapse/storage/keys.py46
-rw-r--r--synapse/storage/media_repository.py127
-rw-r--r--synapse/storage/prepare_database.py77
-rw-r--r--synapse/storage/presence.py4
-rw-r--r--synapse/storage/profile.py155
-rw-r--r--synapse/storage/push_rule.py97
-rw-r--r--synapse/storage/pusher.py135
-rw-r--r--synapse/storage/receipts.py126
-rw-r--r--synapse/storage/registration.py191
-rw-r--r--synapse/storage/room.py456
-rw-r--r--synapse/storage/roommember.py714
-rw-r--r--synapse/storage/schema/delta/14/upgrade_appservice_db.py3
-rw-r--r--synapse/storage/schema/delta/25/fts.py4
-rw-r--r--synapse/storage/schema/delta/27/ts.py4
-rw-r--r--synapse/storage/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/schema/delta/31/search_update.py4
-rw-r--r--synapse/storage/schema/delta/33/event_fields.py4
-rw-r--r--synapse/storage/schema/delta/37/remove_auth_idx.py4
-rw-r--r--synapse/storage/schema/delta/38/postgres_fts_gist.sql6
-rw-r--r--synapse/storage/schema/delta/40/event_push_summary.sql37
-rw-r--r--synapse/storage/schema/delta/40/pushers.sql39
-rw-r--r--synapse/storage/schema/delta/41/device_list_stream_idx.sql (renamed from synapse/storage/schema/delta/23/refresh_tokens.sql)10
-rw-r--r--synapse/storage/schema/delta/41/device_outbound_index.sql16
-rw-r--r--synapse/storage/schema/delta/41/event_search_event_id_idx.sql17
-rw-r--r--synapse/storage/schema/delta/41/ratelimit.sql22
-rw-r--r--synapse/storage/schema/delta/42/current_state_delta.sql26
-rw-r--r--synapse/storage/schema/delta/42/device_list_last_id.sql33
-rw-r--r--synapse/storage/schema/delta/42/event_auth_state_only.sql (renamed from synapse/storage/schema/delta/33/refreshtoken_device_index.sql)4
-rw-r--r--synapse/storage/schema/delta/42/user_dir.py84
-rw-r--r--synapse/storage/schema/delta/43/blocked_rooms.sql21
-rw-r--r--synapse/storage/schema/delta/43/quarantine_media.sql17
-rw-r--r--synapse/storage/schema/delta/43/url_cache.sql (renamed from synapse/storage/schema/delta/33/refreshtoken_device.sql)4
-rw-r--r--synapse/storage/schema/delta/43/user_share.sql33
-rw-r--r--synapse/storage/schema/delta/44/expire_url_cache.sql41
-rw-r--r--synapse/storage/schema/delta/45/group_server.sql167
-rw-r--r--synapse/storage/schema/delta/45/profile_cache.sql28
-rw-r--r--synapse/storage/schema/delta/46/drop_refresh_tokens.sql17
-rw-r--r--synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql35
-rw-r--r--synapse/storage/schema/delta/46/group_server.sql32
-rw-r--r--synapse/storage/schema/delta/46/local_media_repository_url_idx.sql24
-rw-r--r--synapse/storage/schema/delta/46/user_dir_null_room_ids.sql35
-rw-r--r--synapse/storage/schema/delta/46/user_dir_typos.sql24
-rw-r--r--synapse/storage/schema/delta/47/last_access_media.sql16
-rw-r--r--synapse/storage/schema/delta/47/postgres_fts_gin.sql17
-rw-r--r--synapse/storage/schema/delta/47/push_actions_staging.sql28
-rw-r--r--synapse/storage/schema/delta/47/state_group_seq.py37
-rw-r--r--synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql17
-rw-r--r--synapse/storage/schema/delta/48/group_unique_indexes.py57
-rw-r--r--synapse/storage/schema/delta/48/groups_joinable.sql22
-rw-r--r--synapse/storage/schema/schema_version.sql7
-rw-r--r--synapse/storage/search.py203
-rw-r--r--synapse/storage/signatures.py14
-rw-r--r--synapse/storage/state.py635
-rw-r--r--synapse/storage/stream.py326
-rw-r--r--synapse/storage/tags.py26
-rw-r--r--synapse/storage/transactions.py6
-rw-r--r--synapse/storage/user_directory.py764
-rw-r--r--synapse/storage/util/id_generators.py14
-rw-r--r--synapse/streams/config.py6
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py134
-rw-r--r--synapse/util/__init__.py62
-rw-r--r--synapse/util/async.py151
-rw-r--r--synapse/util/caches/__init__.py33
-rw-r--r--synapse/util/caches/descriptors.py412
-rw-r--r--synapse/util/caches/dictionary_cache.py61
-rw-r--r--synapse/util/caches/expiringcache.py16
-rw-r--r--synapse/util/caches/lrucache.py43
-rw-r--r--synapse/util/caches/response_cache.py106
-rw-r--r--synapse/util/caches/stream_change_cache.py23
-rw-r--r--synapse/util/distributor.py22
-rw-r--r--synapse/util/file_consumer.py141
-rw-r--r--synapse/util/frozenutils.py19
-rw-r--r--synapse/util/httpresourcetree.py11
-rw-r--r--synapse/util/logcontext.py199
-rw-r--r--synapse/util/logformatter.py51
-rw-r--r--synapse/util/metrics.py75
-rw-r--r--synapse/util/module_loader.py42
-rw-r--r--synapse/util/msisdn.py40
-rw-r--r--synapse/util/ratelimitutils.py4
-rw-r--r--synapse/util/retryutils.py50
-rw-r--r--synapse/util/stringutils.py19
-rw-r--r--synapse/util/threepids.py48
-rw-r--r--synapse/util/wheel_timer.py9
-rw-r--r--synapse/visibility.py47
288 files changed, 30549 insertions, 9750 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index da8ef90a77..f31cb9a3cb 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.19.1"
+__version__ = "0.28.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab1b..f17fda6315 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,8 @@ from synapse import event_auth
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes
 from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -39,6 +40,10 @@ AuthEventTypes = (
 GUEST_DEVICE_ID = "guest_device"
 
 
+class _InvalidMacaroonException(Exception):
+    pass
+
+
 class Auth(object):
     """
     FIXME: This class contains a mix of functions for authenticating users
@@ -51,6 +56,9 @@ class Auth(object):
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
 
+        self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+        register_cache("token_cache", self.token_cache)
+
     @defer.inlineCallbacks
     def check_from_context(self, event, context, do_sig_check=True):
         auth_events_ids = yield self.compute_auth_events(
@@ -144,17 +152,8 @@ class Auth(object):
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
         with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-
-            logger.debug("calling resolve_state_groups from check_host_in_room")
-            entry = yield self.state.resolve_state_groups(
-                room_id, latest_event_ids
-            )
-
-            ret = yield self.store.is_host_joined(
-                room_id, host, entry.state_group, entry.state
-            )
-            defer.returnValue(ret)
+            latest_event_ids = yield self.store.is_host_joined(room_id, host)
+            defer.returnValue(latest_event_ids)
 
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
@@ -205,13 +204,12 @@ class Auth(object):
 
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
-                "User-Agent",
-                default=[""]
+                b"User-Agent",
+                default=[b""]
             )[0]
             if user and access_token and ip_addr:
-                preserve_context_over_fn(
-                    self.store.insert_client_ip,
-                    user=user,
+                self.store.insert_client_ip(
+                    user_id=user.to_string(),
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
@@ -272,13 +270,17 @@ class Auth(object):
             rights (str): The operation being performed; the access token must
                 allow this.
         Returns:
-            dict : dict that includes the user and the ID of their access token.
+            Deferred[dict]: dict that includes:
+               `user` (UserID)
+               `is_guest` (bool)
+               `token_id` (int|None): access token id. May be None if guest
+               `device_id` (str|None): device corresponding to access token
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            macaroon = pymacaroons.Macaroon.deserialize(token)
-        except Exception:  # deserialize can throw more-or-less anything
+            user_id, guest = self._parse_and_validate_macaroon(token, rights)
+        except _InvalidMacaroonException:
             # doesn't look like a macaroon: treat it as an opaque token which
             # must be in the database.
             # TODO: it would be nice to get rid of this, but apparently some
@@ -287,19 +289,8 @@ class Auth(object):
             defer.returnValue(r)
 
         try:
-            user_id = self.get_user_id_from_macaroon(macaroon)
             user = UserID.from_string(user_id)
 
-            self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token,
-                user_id=user_id,
-            )
-
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id == "guest = true":
-                    guest = True
-
             if guest:
                 # Guest access tokens are not stored in the database (there can
                 # only be one access token per guest, anyway).
@@ -371,6 +362,55 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )
 
+    def _parse_and_validate_macaroon(self, token, rights="access"):
+        """Takes a macaroon and tries to parse and validate it. This is cached
+        if and only if rights == access and there isn't an expiry.
+
+        On invalid macaroon raises _InvalidMacaroonException
+
+        Returns:
+            (user_id, is_guest)
+        """
+        if rights == "access":
+            cached = self.token_cache.get(token, None)
+            if cached:
+                return cached
+
+        try:
+            macaroon = pymacaroons.Macaroon.deserialize(token)
+        except Exception:  # deserialize can throw more-or-less anything
+            # doesn't look like a macaroon: treat it as an opaque token which
+            # must be in the database.
+            # TODO: it would be nice to get rid of this, but apparently some
+            # people use access tokens which aren't macaroons
+            raise _InvalidMacaroonException()
+
+        try:
+            user_id = self.get_user_id_from_macaroon(macaroon)
+
+            has_expiry = False
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id.startswith("time "):
+                    has_expiry = True
+                elif caveat.caveat_id == "guest = true":
+                    guest = True
+
+            self.validate_macaroon(
+                macaroon, rights, self.hs.config.expire_access_token,
+                user_id=user_id,
+            )
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+
+        if not has_expiry and rights == "access":
+            self.token_cache[token] = (user_id, guest)
+
+        return user_id, guest
+
     def get_user_id_from_macaroon(self, macaroon):
         """Retrieve the user_id given by the caveats on the macaroon.
 
@@ -483,6 +523,14 @@ class Auth(object):
             )
 
     def is_server_admin(self, user):
+        """ Check if the given user is a local server admin.
+
+        Args:
+            user (str): mxid of user to check
+
+        Returns:
+            bool: True if the user is an admin
+        """
         return self.store.is_server_admin(user)
 
     @defer.inlineCallbacks
@@ -624,7 +672,7 @@ def has_access_token(request):
         bool: False if no access_token was given, True otherwise.
     """
     query_params = request.args.get("access_token")
-    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+    auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
     return bool(query_params) or bool(auth_headers)
 
 
@@ -644,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
         AuthError: If there isn't an access_token in the request.
     """
 
-    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
-    query_params = request.args.get("access_token")
+    auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+    query_params = request.args.get(b"access_token")
     if auth_headers:
         # Try the get the access_token from a "Authorization: Bearer"
         # header
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index ca23c9c460..5baba43966 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +16,9 @@
 
 """Contains constants from the specification."""
 
+# the "depth" field on events is limited to 2**63 - 1
+MAX_DEPTH = 2**63 - 1
+
 
 class Membership(object):
 
@@ -44,6 +48,7 @@ class JoinRules(object):
 class LoginType(object):
     PASSWORD = u"m.login.password"
     EMAIL_IDENTITY = u"m.login.email.identity"
+    MSISDN = u"m.login.msisdn"
     RECAPTCHA = u"m.login.recaptcha"
     DUMMY = u"m.login.dummy"
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 921c457738..a9ff5576f3 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,9 @@
 
 import logging
 
+import simplejson as json
+from six import iteritems
+
 logger = logging.getLogger(__name__)
 
 
@@ -45,32 +48,52 @@ class Codes(object):
     THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
     THREEPID_IN_USE = "M_THREEPID_IN_USE"
     THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+    THREEPID_DENIED = "M_THREEPID_DENIED"
     INVALID_USERNAME = "M_INVALID_USERNAME"
     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
 
 
 class CodeMessageException(RuntimeError):
-    """An exception with integer code and message string attributes."""
+    """An exception with integer code and message string attributes.
 
+    Attributes:
+        code (int): HTTP error code
+        msg (str): string describing the error
+    """
     def __init__(self, code, msg):
         super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
         self.code = code
         self.msg = msg
-        self.response_code_message = None
 
     def error_dict(self):
         return cs_error(self.msg)
 
 
+class MatrixCodeMessageException(CodeMessageException):
+    """An error from a general matrix endpoint, eg. from a proxied Matrix API call.
+
+    Attributes:
+        errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+    """
+    def __init__(self, code, msg, errcode=Codes.UNKNOWN):
+        super(MatrixCodeMessageException, self).__init__(code, msg)
+        self.errcode = errcode
+
+
 class SynapseError(CodeMessageException):
-    """A base error which can be caught for all synapse events."""
+    """A base exception type for matrix errors which have an errcode and error
+    message (as well as an HTTP status code).
+
+    Attributes:
+        errcode (str): Matrix error code e.g 'M_FORBIDDEN'
+    """
     def __init__(self, code, msg, errcode=Codes.UNKNOWN):
         """Constructs a synapse error.
 
         Args:
             code (int): The integer error code (an HTTP response code)
             msg (str): The human-readable error message.
-            err (str): The error code e.g 'M_FORBIDDEN'
+            errcode (str): The matrix error code e.g 'M_FORBIDDEN'
         """
         super(SynapseError, self).__init__(code, msg)
         self.errcode = errcode
@@ -81,12 +104,87 @@ class SynapseError(CodeMessageException):
             self.errcode,
         )
 
+    @classmethod
+    def from_http_response_exception(cls, err):
+        """Make a SynapseError based on an HTTPResponseException
+
+        This is useful when a proxied request has failed, and we need to
+        decide how to map the failure onto a matrix error to send back to the
+        client.
+
+        An attempt is made to parse the body of the http response as a matrix
+        error. If that succeeds, the errcode and error message from the body
+        are used as the errcode and error message in the new synapse error.
+
+        Otherwise, the errcode is set to M_UNKNOWN, and the error message is
+        set to the reason code from the HTTP response.
+
+        Args:
+            err (HttpResponseException):
+
+        Returns:
+            SynapseError:
+        """
+        # try to parse the body as json, to get better errcode/msg, but
+        # default to M_UNKNOWN with the HTTP status as the error text
+        try:
+            j = json.loads(err.response)
+        except ValueError:
+            j = {}
+        errcode = j.get('errcode', Codes.UNKNOWN)
+        errmsg = j.get('error', err.msg)
+
+        res = SynapseError(err.code, errmsg, errcode)
+        return res
+
 
 class RegistrationError(SynapseError):
     """An error raised when a registration event fails."""
     pass
 
 
+class FederationDeniedError(SynapseError):
+    """An error raised when the server tries to federate with a server which
+    is not on its federation whitelist.
+
+    Attributes:
+        destination (str): The destination which has been denied
+    """
+
+    def __init__(self, destination):
+        """Raised by federation client or server to indicate that we are
+        are deliberately not attempting to contact a given server because it is
+        not on our federation whitelist.
+
+        Args:
+            destination (str): the domain in question
+        """
+
+        self.destination = destination
+
+        super(FederationDeniedError, self).__init__(
+            code=403,
+            msg="Federation denied with %s." % (self.destination,),
+            errcode=Codes.FORBIDDEN,
+        )
+
+
+class InteractiveAuthIncompleteError(Exception):
+    """An error raised when UI auth is not yet complete
+
+    (This indicates we should return a 401 with 'result' as the body)
+
+    Attributes:
+        result (dict): the server response to the request, which should be
+            passed back to the client
+    """
+    def __init__(self, result):
+        super(InteractiveAuthIncompleteError, self).__init__(
+            "Interactive auth not yet complete",
+        )
+        self.result = result
+
+
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):
@@ -106,13 +204,11 @@ class UnrecognizedRequestError(SynapseError):
 
 class NotFoundError(SynapseError):
     """An error indicating we can't find the thing you asked for"""
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.NOT_FOUND
+    def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
         super(NotFoundError, self).__init__(
             404,
-            "Not found",
-            **kwargs
+            msg,
+            errcode=errcode
         )
 
 
@@ -173,7 +269,6 @@ class LimitExceededError(SynapseError):
                  errcode=Codes.LIMIT_EXCEEDED):
         super(LimitExceededError, self).__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
-        self.response_code_message = "Too Many Requests"
 
     def error_dict(self):
         return cs_error(
@@ -203,7 +298,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
         A dict representing the error response JSON.
     """
     err = {"error": msg, "errcode": code}
-    for key, value in kwargs.iteritems():
+    for key, value in iteritems(kwargs):
         err[key] = value
     return err
 
@@ -243,6 +338,19 @@ class FederationError(RuntimeError):
 
 
 class HttpResponseException(CodeMessageException):
+    """
+    Represents an HTTP-level failure of an outbound request
+
+    Attributes:
+        response (str): body of response
+    """
     def __init__(self, code, msg, response):
-        self.response = response
+        """
+
+        Args:
+            code (int): HTTP status code
+            msg (str): reason phrase from HTTP response status line
+            response (str): body of response
+        """
         super(HttpResponseException, self).__init__(code, msg)
+        self.response = response
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index fb291d7fb9..db43219d24 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -13,11 +13,174 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.errors import SynapseError
+from synapse.storage.presence import UserPresenceState
 from synapse.types import UserID, RoomID
-
 from twisted.internet import defer
 
-import ujson as json
+import simplejson as json
+import jsonschema
+from jsonschema import FormatChecker
+
+FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        # TODO: We don't limit event type values but we probably should...
+        # check types are valid event types
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        }
+    }
+}
+
+ROOM_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "ephemeral": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "include_leave": {
+            "type": "boolean"
+        },
+        "state": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "timeline": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/room_event_filter"
+        },
+    }
+}
+
+ROOM_EVENT_FILTER_SCHEMA = {
+    "additionalProperties": False,
+    "type": "object",
+    "properties": {
+        "limit": {
+            "type": "number"
+        },
+        "senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "not_senders": {
+            "$ref": "#/definitions/user_id_array"
+        },
+        "types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "not_types": {
+            "type": "array",
+            "items": {
+                "type": "string"
+            }
+        },
+        "rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "not_rooms": {
+            "$ref": "#/definitions/room_id_array"
+        },
+        "contains_url": {
+            "type": "boolean"
+        }
+    }
+}
+
+USER_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_user_id"
+    }
+}
+
+ROOM_ID_ARRAY_SCHEMA = {
+    "type": "array",
+    "items": {
+        "type": "string",
+        "format": "matrix_room_id"
+    }
+}
+
+USER_FILTER_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-04/schema#",
+    "description": "schema for a Sync filter",
+    "type": "object",
+    "definitions": {
+        "room_id_array": ROOM_ID_ARRAY_SCHEMA,
+        "user_id_array": USER_ID_ARRAY_SCHEMA,
+        "filter": FILTER_SCHEMA,
+        "room_filter": ROOM_FILTER_SCHEMA,
+        "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+    },
+    "properties": {
+        "presence": {
+            "$ref": "#/definitions/filter"
+        },
+        "account_data": {
+            "$ref": "#/definitions/filter"
+        },
+        "room": {
+            "$ref": "#/definitions/room_filter"
+        },
+        "event_format": {
+            "type": "string",
+            "enum": ["client", "federation"]
+        },
+        "event_fields": {
+            "type": "array",
+            "items": {
+                "type": "string",
+                # Don't allow '\\' in event field filters. This makes matching
+                # events a lot easier as we can then use a negative lookbehind
+                # assertion to split '\.' If we allowed \\ then it would
+                # incorrectly split '\\.' See synapse.events.utils.serialize_event
+                "pattern": "^((?!\\\).)*$"
+            }
+        }
+    },
+    "additionalProperties": False
+}
+
+
+@FormatChecker.cls_checks('matrix_room_id')
+def matrix_room_id_validator(room_id_str):
+    return RoomID.from_string(room_id_str)
+
+
+@FormatChecker.cls_checks('matrix_user_id')
+def matrix_user_id_validator(user_id_str):
+    return UserID.from_string(user_id_str)
 
 
 class Filtering(object):
@@ -52,98 +215,11 @@ class Filtering(object):
         # NB: Filters are the complete json blobs. "Definitions" are an
         # individual top-level key e.g. public_user_data. Filters are made of
         # many definitions.
-
-        top_level_definitions = [
-            "presence", "account_data"
-        ]
-
-        room_level_definitions = [
-            "state", "timeline", "ephemeral", "account_data"
-        ]
-
-        for key in top_level_definitions:
-            if key in user_filter_json:
-                self._check_definition(user_filter_json[key])
-
-        if "room" in user_filter_json:
-            self._check_definition_room_lists(user_filter_json["room"])
-            for key in room_level_definitions:
-                if key in user_filter_json["room"]:
-                    self._check_definition(user_filter_json["room"][key])
-
-        if "event_fields" in user_filter_json:
-            if type(user_filter_json["event_fields"]) != list:
-                raise SynapseError(400, "event_fields must be a list of strings")
-            for field in user_filter_json["event_fields"]:
-                if not isinstance(field, basestring):
-                    raise SynapseError(400, "Event field must be a string")
-                # Don't allow '\\' in event field filters. This makes matching
-                # events a lot easier as we can then use a negative lookbehind
-                # assertion to split '\.' If we allowed \\ then it would
-                # incorrectly split '\\.' See synapse.events.utils.serialize_event
-                if r'\\' in field:
-                    raise SynapseError(
-                        400, r'The escape character \ cannot itself be escaped'
-                    )
-
-    def _check_definition_room_lists(self, definition):
-        """Check that "rooms" and "not_rooms" are lists of room ids if they
-        are present
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # check rooms are valid room IDs
-        room_id_keys = ["rooms", "not_rooms"]
-        for key in room_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for room_id in definition[key]:
-                    RoomID.from_string(room_id)
-
-    def _check_definition(self, definition):
-        """Check if the provided definition is valid.
-
-        This inspects not only the types but also the values to make sure they
-        make sense.
-
-        Args:
-            definition(dict): The filter definition
-        Raises:
-            SynapseError: If there was a problem with this definition.
-        """
-        # NB: Filters are the complete json blobs. "Definitions" are an
-        # individual top-level key e.g. public_user_data. Filters are made of
-        # many definitions.
-        if type(definition) != dict:
-            raise SynapseError(
-                400, "Expected JSON object, not %s" % (definition,)
-            )
-
-        self._check_definition_room_lists(definition)
-
-        # check senders are valid user IDs
-        user_id_keys = ["senders", "not_senders"]
-        for key in user_id_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for user_id in definition[key]:
-                    UserID.from_string(user_id)
-
-        # TODO: We don't limit event type values but we probably should...
-        # check types are valid event types
-        event_keys = ["types", "not_types"]
-        for key in event_keys:
-            if key in definition:
-                if type(definition[key]) != list:
-                    raise SynapseError(400, "Expected %s to be a list." % key)
-                for event_type in definition[key]:
-                    if not isinstance(event_type, basestring):
-                        raise SynapseError(400, "Event type should be a string")
+        try:
+            jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
+                                format_checker=FormatChecker())
+        except jsonschema.ValidationError as e:
+            raise SynapseError(400, e.message)
 
 
 class FilterCollection(object):
@@ -253,19 +329,35 @@ class Filter(object):
         Returns:
             bool: True if the event matches
         """
-        sender = event.get("sender", None)
-        if not sender:
-            # Presence events have their 'sender' in content.user_id
-            content = event.get("content")
-            # account_data has been allowed to have non-dict content, so check type first
-            if isinstance(content, dict):
-                sender = content.get("user_id")
+        # We usually get the full "events" as dictionaries coming through,
+        # except for presence which actually gets passed around as its own
+        # namedtuple type.
+        if isinstance(event, UserPresenceState):
+            sender = event.user_id
+            room_id = None
+            ev_type = "m.presence"
+            is_url = False
+        else:
+            sender = event.get("sender", None)
+            if not sender:
+                # Presence events had their 'sender' in content.user_id, but are
+                # now handled above. We don't know if anything else uses this
+                # form. TODO: Check this and probably remove it.
+                content = event.get("content")
+                # account_data has been allowed to have non-dict content, so
+                # check type first
+                if isinstance(content, dict):
+                    sender = content.get("user_id")
+
+            room_id = event.get("room_id", None)
+            ev_type = event.get("type", None)
+            is_url = "url" in event.get("content", {})
 
         return self.check_fields(
-            event.get("room_id", None),
+            room_id,
             sender,
-            event.get("type", None),
-            "url" in event.get("content", {})
+            ev_type,
+            is_url,
         )
 
     def check_fields(self, room_id, sender, event_type, contains_url):
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
new file mode 100644
index 0000000000..e4318cdfc3
--- /dev/null
+++ b/synapse/app/_base.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import logging
+import sys
+
+try:
+    import affinity
+except Exception:
+    affinity = None
+
+from daemonize import Daemonize
+from synapse.util import PreserveLoggingContext
+from synapse.util.rlimit import change_resource_limit
+from twisted.internet import error, reactor
+
+logger = logging.getLogger(__name__)
+
+
+def start_worker_reactor(appname, config):
+    """ Run the reactor in the main process
+
+    Daemonizes if necessary, and then configures some resources, before starting
+    the reactor. Pulls configuration from the 'worker' settings in 'config'.
+
+    Args:
+        appname (str): application name which will be sent to syslog
+        config (synapse.config.Config): config object
+    """
+
+    logger = logging.getLogger(config.worker_app)
+
+    start_reactor(
+        appname,
+        config.soft_file_limit,
+        config.gc_thresholds,
+        config.worker_pid_file,
+        config.worker_daemonize,
+        config.worker_cpu_affinity,
+        logger,
+    )
+
+
+def start_reactor(
+        appname,
+        soft_file_limit,
+        gc_thresholds,
+        pid_file,
+        daemonize,
+        cpu_affinity,
+        logger,
+):
+    """ Run the reactor in the main process
+
+    Daemonizes if necessary, and then configures some resources, before starting
+    the reactor
+
+    Args:
+        appname (str): application name which will be sent to syslog
+        soft_file_limit (int):
+        gc_thresholds:
+        pid_file (str): name of pid file to write to if daemonize is True
+        daemonize (bool): true to run the reactor in a background process
+        cpu_affinity (int|None): cpu affinity mask
+        logger (logging.Logger): logger instance to pass to Daemonize
+    """
+
+    def run():
+        # make sure that we run the reactor with the sentinel log context,
+        # otherwise other PreserveLoggingContext instances will get confused
+        # and complain when they see the logcontext arbitrarily swapping
+        # between the sentinel and `run` logcontexts.
+        with PreserveLoggingContext():
+            logger.info("Running")
+            if cpu_affinity is not None:
+                if not affinity:
+                    quit_with_error(
+                        "Missing package 'affinity' required for cpu_affinity\n"
+                        "option\n\n"
+                        "Install by running:\n\n"
+                        "   pip install affinity\n\n"
+                    )
+                logger.info("Setting CPU affinity to %s" % cpu_affinity)
+                affinity.set_process_affinity_mask(0, cpu_affinity)
+            change_resource_limit(soft_file_limit)
+            if gc_thresholds:
+                gc.set_threshold(*gc_thresholds)
+            reactor.run()
+
+    if daemonize:
+        daemon = Daemonize(
+            app=appname,
+            pid=pid_file,
+            action=run,
+            auto_close_fds=False,
+            verbose=True,
+            logger=logger,
+        )
+        daemon.start()
+    else:
+        run()
+
+
+def quit_with_error(error_string):
+    message_lines = error_string.split("\n")
+    line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
+    sys.stderr.write("*" * line_length + '\n')
+    for line in message_lines:
+        sys.stderr.write(" %s\n" % (line.rstrip(),))
+    sys.stderr.write("*" * line_length + '\n')
+    sys.exit(1)
+
+
+def listen_tcp(bind_addresses, port, factory, backlog=50):
+    """
+    Create a TCP socket for a port and several addresses
+    """
+    for address in bind_addresses:
+        try:
+            reactor.listenTCP(
+                port,
+                factory,
+                backlog,
+                address
+            )
+        except error.CannotListenError as e:
+            check_bind_error(e, address, bind_addresses)
+
+
+def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
+    """
+    Create an SSL socket for a port and several addresses
+    """
+    for address in bind_addresses:
+        try:
+            reactor.listenSSL(
+                port,
+                factory,
+                context_factory,
+                backlog,
+                address
+            )
+        except error.CannotListenError as e:
+            check_bind_error(e, address, bind_addresses)
+
+
+def check_bind_error(e, address, bind_addresses):
+    """
+    This method checks an exception occurred while binding on 0.0.0.0.
+    If :: is specified in the bind addresses a warning is shown.
+    The exception is still raised otherwise.
+
+    Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
+    because :: binds on both IPv4 and IPv6 (as per RFC 3493).
+    When binding on 0.0.0.0 after :: this can safely be ignored.
+
+    Args:
+        e (Exception): Exception that was caught.
+        address (str): Address on which binding was attempted.
+        bind_addresses (list): Addresses on which the service listens.
+    """
+    if address == '0.0.0.0' and '::' in bind_addresses:
+        logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+    else:
+        raise e
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 1900930053..58f2c9d68c 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -13,37 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
 from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
 from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-
-from synapse import events
-
 from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.appservice")
 
@@ -56,19 +49,6 @@ class AppserviceSlaveStore(
 
 
 class AppserviceServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@@ -84,19 +64,18 @@ class AppserviceServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse appservice now listening on port %d", port)
 
@@ -105,45 +84,42 @@ class AppserviceServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ASReplicationHandler(self)
+
+
+class ASReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(ASReplicationHandler, self).__init__(hs.get_datastore())
+        self.appservice_handler = hs.get_application_service_handler()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
+
+        if stream_name == "events":
+            max_stream_id = self.store.get_room_max_stream_ordering()
+            run_in_background(self._notify_app_services, max_stream_id)
+
     @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        appservice_handler = self.get_application_service_handler()
-
-        @defer.inlineCallbacks
-        def replicate(results):
-            stream = results.get("events")
-            if stream:
-                max_stream_id = stream["position"]
-                yield appservice_handler.notify_interested_services(max_stream_id)
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                replicate(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+    def _notify_app_services(self, room_stream_id):
+        try:
+            yield self.appservice_handler.notify_interested_services(room_stream_id)
+        except Exception:
+            logger.exception("Error notifying application services of event")
 
 
 def start(config_options):
@@ -157,7 +133,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.appservice"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -186,33 +162,13 @@ def start(config_options):
     ps.setup()
     ps.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
-        ps.replicate()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-appservice",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-appservice", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 4d081eccd1..267d34c881 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -13,45 +13,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
+from synapse import events
+from synapse.app import _base
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
+from synapse.crypto import context_factory
 from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.client.v1.room import PublicRoomListRestServlet
 from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.client_reader")
 
@@ -63,26 +56,14 @@ class ClientReaderSlavedStore(
     DirectoryStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
+    TransactionStore,
+    SlavedClientIpStore,
     BaseSlavedStore,
-    ClientIpStore,  # After BaseSlavedStore because the constructor is different
 ):
     pass
 
 
 class ClientReaderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@@ -107,19 +88,18 @@ class ClientReaderServer(HomeServer):
                         "/_matrix/client/api/v1": resource,
                     })
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse client reader now listening on port %d", port)
 
@@ -128,36 +108,23 @@ class ClientReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
+
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
+        self.get_tcp_replication().start_replication(self)
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())
 
 
 def start(config_options):
@@ -171,7 +138,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.client_reader"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -189,36 +156,15 @@ def start(config_options):
     )
 
     ss.setup()
-    ss.get_handlers()
     ss.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-client-reader",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-client-reader", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
new file mode 100644
index 0000000000..b915d12d53
--- /dev/null
+++ b/synapse/app/event_creator.py
@@ -0,0 +1,189 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1.room import (
+    RoomSendEventRestServlet, RoomMembershipRestServlet, RoomStateEventRestServlet,
+    JoinRoomAliasServlet,
+)
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.event_creator")
+
+
+class EventCreatorSlavedStore(
+    DirectoryStore,
+    TransactionStore,
+    SlavedProfileStore,
+    SlavedAccountDataStore,
+    SlavedPusherStore,
+    SlavedReceiptsStore,
+    SlavedPushRuleStore,
+    SlavedDeviceStore,
+    SlavedClientIpStore,
+    SlavedApplicationServiceStore,
+    SlavedEventStore,
+    SlavedRegistrationStore,
+    RoomStore,
+    BaseSlavedStore,
+):
+    pass
+
+
+class EventCreatorServer(HomeServer):
+    def setup(self):
+        logger.info("Setting up.")
+        self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
+        logger.info("Finished setting up.")
+
+    def _listen_http(self, listener_config):
+        port = listener_config["port"]
+        bind_addresses = listener_config["bind_addresses"]
+        site_tag = listener_config.get("tag", port)
+        resources = {}
+        for res in listener_config["resources"]:
+            for name in res["names"]:
+                if name == "metrics":
+                    resources[METRICS_PREFIX] = MetricsResource(self)
+                elif name == "client":
+                    resource = JsonResource(self, canonical_json=False)
+                    RoomSendEventRestServlet(self).register(resource)
+                    RoomMembershipRestServlet(self).register(resource)
+                    RoomStateEventRestServlet(self).register(resource)
+                    JoinRoomAliasServlet(self).register(resource)
+                    resources.update({
+                        "/_matrix/client/r0": resource,
+                        "/_matrix/client/unstable": resource,
+                        "/_matrix/client/v2_alpha": resource,
+                        "/_matrix/client/api/v1": resource,
+                    })
+
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
+            )
+        )
+
+        logger.info("Synapse event creator now listening on port %d", port)
+
+    def start_listening(self, listeners):
+        for listener in listeners:
+            if listener["type"] == "http":
+                self._listen_http(listener)
+            elif listener["type"] == "manhole":
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
+                    )
+                )
+            else:
+                logger.warn("Unrecognized listener type: %s", listener["type"])
+
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+    try:
+        config = HomeServerConfig.load_config(
+            "Synapse event creator", config_options
+        )
+    except ConfigError as e:
+        sys.stderr.write("\n" + e.message + "\n")
+        sys.exit(1)
+
+    assert config.worker_app == "synapse.app.event_creator"
+
+    assert config.worker_replication_http_port is not None
+
+    setup_logging(config, use_worker_options=True)
+
+    events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+    database_engine = create_engine(config.database_config)
+
+    tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+    ss = EventCreatorServer(
+        config.server_name,
+        db_config=config.database_config,
+        tls_server_context_factory=tls_server_context_factory,
+        config=config,
+        version_string="Synapse/" + get_version_string(synapse),
+        database_engine=database_engine,
+    )
+
+    ss.setup()
+    ss.start_listening(config.worker_listeners)
+
+    def start():
+        ss.get_state_handler().start_caching()
+        ss.get_datastore().start_profiling()
+
+    reactor.callWhenRunning(start)
+
+    _base.start_worker_reactor("synapse-event-creator", config)
+
+
+if __name__ == '__main__':
+    with LoggingContext("main"):
+        start(sys.argv[1:])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 90a4816753..c1dc66dd17 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -13,43 +13,35 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
+from synapse import events
+from synapse.api.urls import FEDERATION_PREFIX
+from synapse.app import _base
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.directory import DirectoryStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-from synapse.api.urls import FEDERATION_PREFIX
-from synapse.federation.transport.server import TransportLayerServer
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.federation_reader")
 
@@ -66,19 +58,6 @@ class FederationReaderSlavedStore(
 
 
 class FederationReaderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@@ -98,19 +77,18 @@ class FederationReaderServer(HomeServer):
                         FEDERATION_PREFIX: TransportLayerServer(self),
                     })
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse federation reader now listening on port %d", port)
 
@@ -119,36 +97,22 @@ class FederationReaderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
+        self.get_tcp_replication().start_replication(self)
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())
 
 
 def start(config_options):
@@ -162,7 +126,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.federation_reader"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -180,36 +144,15 @@ def start(config_options):
     )
 
     ss.setup()
-    ss.get_handlers()
     ss.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-federation-reader",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-federation-reader", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 411e47d98d..a08af83a4c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -13,69 +13,69 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
 from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
 from synapse.crypto import context_factory
-from synapse.http.site import SynapseSite
 from synapse.federation import send_queue
-from synapse.federation.units import Edu
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
 from synapse.storage.engines import create_engine
-from synapse.storage.presence import UserPresenceState
-from synapse.util.async import sleep
+from synapse.util.async import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
 
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
-import ujson as json
-
-logger = logging.getLogger("synapse.app.appservice")
+logger = logging.getLogger("synapse.app.federation_sender")
 
 
 class FederationSenderSlaveStore(
     SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
-    SlavedRegistrationStore, SlavedDeviceStore,
+    SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
 ):
-    pass
+    def __init__(self, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+
+        # We pull out the current federation stream position now so that we
+        # always have a known value for the federation position in memory so
+        # that we don't have to bounce via a deferred once when we start the
+        # replication streams.
+        self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+    def _get_federation_out_pos(self, db_conn):
+        sql = (
+            "SELECT stream_id FROM federation_stream_position"
+            " WHERE type = ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
 
+        txn = db_conn.cursor()
+        txn.execute(sql, ("federation",))
+        rows = txn.fetchall()
+        txn.close()
+
+        return rows[0][0] if rows else -1
 
-class FederationSenderServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
 
+class FederationSenderServer(HomeServer):
     def setup(self):
         logger.info("Setting up.")
         self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@@ -91,19 +91,18 @@ class FederationSenderServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse federation_sender now listening on port %d", port)
 
@@ -112,41 +111,39 @@ class FederationSenderServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        send_handler = FederationSenderHandler(self)
-
-        send_handler.on_start()
-
-        while True:
-            try:
-                args = store.stream_positions()
-                args.update((yield send_handler.stream_positions()))
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                yield send_handler.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return FederationSenderReplicationHandler(self)
+
+
+class FederationSenderReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
+        self.send_handler = FederationSenderHandler(hs, self)
+
+    def on_rdata(self, stream_name, token, rows):
+        super(FederationSenderReplicationHandler, self).on_rdata(
+            stream_name, token, rows
+        )
+        self.send_handler.process_replication_rows(stream_name, token, rows)
+
+    def get_streams_to_replicate(self):
+        args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+        args.update(self.send_handler.stream_positions())
+        return args
 
 
 def start(config_options):
@@ -160,7 +157,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.federation_sender"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -192,42 +189,27 @@ def start(config_options):
     ps.setup()
     ps.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
-        ps.replicate()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()
 
     reactor.callWhenRunning(start)
-
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-federation-sender",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-federation-sender", config)
 
 
 class FederationSenderHandler(object):
     """Processes the replication stream and forwards the appropriate entries
     to the federation sender.
     """
-    def __init__(self, hs):
+    def __init__(self, hs, replication_client):
         self.store = hs.get_datastore()
         self.federation_sender = hs.get_federation_sender()
+        self.replication_client = replication_client
+
+        self.federation_position = self.store.federation_out_pos_startup
+        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+        self._last_ack = self.federation_position
 
         self._room_serials = {}
         self._room_typing = {}
@@ -239,98 +221,38 @@ class FederationSenderHandler(object):
             self.store.get_room_max_stream_ordering()
         )
 
-    @defer.inlineCallbacks
     def stream_positions(self):
-        stream_id = yield self.store.get_federation_out_pos("federation")
-        defer.returnValue({
-            "federation": stream_id,
-
-            # Ack stuff we've "processed", this should only be called from
-            # one process.
-            "federation_ack": stream_id,
-        })
+        return {"federation": self.federation_position}
 
-    @defer.inlineCallbacks
-    def process_replication(self, result):
+    def process_replication_rows(self, stream_name, token, rows):
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
-        fed_stream = result.get("federation")
-        if fed_stream:
-            latest_id = int(fed_stream["position"])
-
-            # The federation stream containis a bunch of different types of
-            # rows that need to be handled differently. We parse the rows, put
-            # them into the appropriate collection and then send them off.
-            presence_to_send = {}
-            keyed_edus = {}
-            edus = {}
-            failures = {}
-            device_destinations = set()
-
-            # Parse the rows in the stream
-            for row in fed_stream["rows"]:
-                position, typ, content_js = row
-                content = json.loads(content_js)
-
-                if typ == send_queue.PRESENCE_TYPE:
-                    destination = content["destination"]
-                    state = UserPresenceState.from_dict(content["state"])
-
-                    presence_to_send.setdefault(destination, []).append(state)
-                elif typ == send_queue.KEYED_EDU_TYPE:
-                    key = content["key"]
-                    edu = Edu(**content["edu"])
-
-                    keyed_edus.setdefault(
-                        edu.destination, {}
-                    )[(edu.destination, tuple(key))] = edu
-                elif typ == send_queue.EDU_TYPE:
-                    edu = Edu(**content)
-
-                    edus.setdefault(edu.destination, []).append(edu)
-                elif typ == send_queue.FAILURE_TYPE:
-                    destination = content["destination"]
-                    failure = content["failure"]
-
-                    failures.setdefault(destination, []).append(failure)
-                elif typ == send_queue.DEVICE_MESSAGE_TYPE:
-                    device_destinations.add(content["destination"])
-                else:
-                    raise Exception("Unrecognised federation type: %r", typ)
-
-            # We've finished collecting, send everything off
-            for destination, states in presence_to_send.items():
-                self.federation_sender.send_presence(destination, states)
-
-            for destination, edu_map in keyed_edus.items():
-                for key, edu in edu_map.items():
-                    self.federation_sender.send_edu(
-                        edu.destination, edu.edu_type, edu.content, key=key,
-                    )
-
-            for destination, edu_list in edus.items():
-                for edu in edu_list:
-                    self.federation_sender.send_edu(
-                        edu.destination, edu.edu_type, edu.content, key=None,
-                    )
+        if stream_name == "federation":
+            send_queue.process_rows_for_federation(self.federation_sender, rows)
+            run_in_background(self.update_token, token)
 
-            for destination, failure_list in failures.items():
-                for failure in failure_list:
-                    self.federation_sender.send_failure(destination, failure)
-
-            for destination in device_destinations:
-                self.federation_sender.send_device_messages(destination)
+        # We also need to poke the federation sender when new events happen
+        elif stream_name == "events":
+            self.federation_sender.notify_new_events(token)
 
-            # Record where we are in the stream.
-            yield self.store.update_federation_out_pos(
-                "federation", latest_id
-            )
+    @defer.inlineCallbacks
+    def update_token(self, token):
+        try:
+            self.federation_position = token
+
+            # We linearize here to ensure we don't have races updating the token
+            with (yield self._fed_position_linearizer.queue(None)):
+                if self._last_ack < self.federation_position:
+                    yield self.store.update_federation_out_pos(
+                        "federation", self.federation_position
+                    )
 
-        # We also need to poke the federation sender when new events happen
-        event_stream = result.get("events")
-        if event_stream:
-            latest_pos = event_stream["position"]
-            self.federation_sender.notify_new_events(latest_pos)
+                    # We ACK this token over replication so that the master can drop
+                    # its in memory queues
+                    self.replication_client.send_federation_ack(self.federation_position)
+                    self._last_ack = self.federation_position
+        except Exception:
+            logger.exception("Error updating federation stream position")
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
new file mode 100644
index 0000000000..b349e3e3ce
--- /dev/null
+++ b/synapse/app/frontend_proxy.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.api.errors import SynapseError
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request,
+)
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha._base import client_v2_patterns
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.frontend_proxy")
+
+
+class KeyUploadServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(KeyUploadServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.http_client = hs.get_simple_http_client()
+        self.main_uri = hs.config.worker_main_http_uri
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, device_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        user_id = requester.user.to_string()
+        body = parse_json_object_from_request(request)
+
+        if device_id is not None:
+            # passing the device_id here is deprecated; however, we allow it
+            # for now for compatibility with older clients.
+            if (requester.device_id is not None and
+                    device_id != requester.device_id):
+                logger.warning("Client uploading keys for a different device "
+                               "(logged in as %s, uploading for %s)",
+                               requester.device_id, device_id)
+        else:
+            device_id = requester.device_id
+
+        if device_id is None:
+            raise SynapseError(
+                400,
+                "To upload keys, you must pass device_id when authenticating"
+            )
+
+        if body:
+            # They're actually trying to upload something, proxy to main synapse.
+            # Pass through the auth headers, if any, in case the access token
+            # is there.
+            auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
+            headers = {
+                "Authorization": auth_headers,
+            }
+            result = yield self.http_client.post_json_get_json(
+                self.main_uri + request.uri,
+                body,
+                headers=headers,
+            )
+
+            defer.returnValue((200, result))
+        else:
+            # Just interested in counts.
+            result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+            defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class FrontendProxySlavedStore(
+    SlavedDeviceStore,
+    SlavedClientIpStore,
+    SlavedApplicationServiceStore,
+    SlavedRegistrationStore,
+    BaseSlavedStore,
+):
+    pass
+
+
+class FrontendProxyServer(HomeServer):
+    def setup(self):
+        logger.info("Setting up.")
+        self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
+        logger.info("Finished setting up.")
+
+    def _listen_http(self, listener_config):
+        port = listener_config["port"]
+        bind_addresses = listener_config["bind_addresses"]
+        site_tag = listener_config.get("tag", port)
+        resources = {}
+        for res in listener_config["resources"]:
+            for name in res["names"]:
+                if name == "metrics":
+                    resources[METRICS_PREFIX] = MetricsResource(self)
+                elif name == "client":
+                    resource = JsonResource(self, canonical_json=False)
+                    KeyUploadServlet(self).register(resource)
+                    resources.update({
+                        "/_matrix/client/r0": resource,
+                        "/_matrix/client/unstable": resource,
+                        "/_matrix/client/v2_alpha": resource,
+                        "/_matrix/client/api/v1": resource,
+                    })
+
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
+            )
+        )
+
+        logger.info("Synapse client reader now listening on port %d", port)
+
+    def start_listening(self, listeners):
+        for listener in listeners:
+            if listener["type"] == "http":
+                self._listen_http(listener)
+            elif listener["type"] == "manhole":
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
+                    )
+                )
+            else:
+                logger.warn("Unrecognized listener type: %s", listener["type"])
+
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())
+
+
+def start(config_options):
+    try:
+        config = HomeServerConfig.load_config(
+            "Synapse frontend proxy", config_options
+        )
+    except ConfigError as e:
+        sys.stderr.write("\n" + e.message + "\n")
+        sys.exit(1)
+
+    assert config.worker_app == "synapse.app.frontend_proxy"
+
+    assert config.worker_main_http_uri is not None
+
+    setup_logging(config, use_worker_options=True)
+
+    events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+    database_engine = create_engine(config.database_config)
+
+    tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+    ss = FrontendProxyServer(
+        config.server_name,
+        db_config=config.database_config,
+        tls_server_context_factory=tls_server_context_factory,
+        config=config,
+        version_string="Synapse/" + get_version_string(synapse),
+        database_engine=database_engine,
+    )
+
+    ss.setup()
+    ss.start_listening(config.worker_listeners)
+
+    def start():
+        ss.get_state_handler().start_caching()
+        ss.get_datastore().start_profiling()
+
+    reactor.callWhenRunning(start)
+
+    _base.start_worker_reactor("synapse-frontend-proxy", config)
+
+
+if __name__ == '__main__':
+    with LoggingContext("main"):
+        start(sys.argv[1:])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e0b87468fe..a0e465d644 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -13,59 +13,53 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import synapse
-
 import gc
 import logging
 import os
 import sys
-from synapse.config._base import ConfigError
-
-from synapse.python_dependencies import (
-    check_requirements, DEPENDENCY_LINKS
-)
-
-from synapse.rest import ClientRestResource
-from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
-from synapse.storage import are_all_users_on_domain
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
-
-from synapse.server import HomeServer
 
-from twisted.internet import reactor, task, defer
-from twisted.application import service
-from twisted.web.resource import Resource, EncodingResourceWrapper
-from twisted.web.static import File
-from twisted.web.server import GzipEncoderFactory
+import synapse
+import synapse.config.logger
+from synapse import events
+from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
+    LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
+    STATIC_PREFIX, WEB_CLIENT_PREFIX
+from synapse.app import _base
+from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.crypto import context_factory
+from synapse.federation.transport.server import TransportLayerServer
+from synapse.module_api import ModuleApi
+from synapse.http.additional_resource import AdditionalResource
 from synapse.http.server import RootRedirect
-from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
+from synapse.http.site import SynapseSite
+from synapse.metrics import register_memory_metrics
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
+    check_requirements
+from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.rest import ClientRestResource
 from synapse.rest.key.v1.server_key_resource import LocalKey
 from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.api.urls import (
-    FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
-    SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
-    SERVER_KEY_V2_PREFIX,
-)
-from synapse.config.homeserver import HomeServerConfig
-from synapse.crypto import context_factory
+from synapse.rest.media.v0.content_repository import ContentRepoResource
+from synapse.server import HomeServer
+from synapse.storage import are_all_users_on_domain
+from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext
-from synapse.metrics import register_memory_metrics, get_metrics_for
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
-from synapse.federation.transport.server import TransportLayerServer
-
+from synapse.util.manhole import manhole
+from synapse.util.module_loader import load_module
 from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-
-from synapse.http.site import SynapseSite
-
-from synapse import events
-
-from daemonize import Daemonize
+from twisted.application import service
+from twisted.internet import defer, reactor
+from twisted.web.resource import EncodingResourceWrapper, NoResource
+from twisted.web.server import GzipEncoderFactory
+from twisted.web.static import File
 
 logger = logging.getLogger("synapse.app.homeserver")
 
@@ -90,7 +84,7 @@ def build_resource_for_web_client(hs):
                 "\n"
                 "You can also disable hosting of the webclient via the\n"
                 "configuration option `web_client`\n"
-                % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
+                % {"dep": CONDITIONAL_REQUIREMENTS["web_client"].keys()[0]}
             )
         syweb_path = os.path.dirname(syweb.__file__)
         webclient_path = os.path.join(syweb_path, "webclient")
@@ -117,90 +111,121 @@ class SynapseHomeServer(HomeServer):
         resources = {}
         for res in listener_config["resources"]:
             for name in res["names"]:
-                if name == "client":
-                    client_resource = ClientRestResource(self)
-                    if res["compress"]:
-                        client_resource = gz_wrap(client_resource)
-
-                    resources.update({
-                        "/_matrix/client/api/v1": client_resource,
-                        "/_matrix/client/r0": client_resource,
-                        "/_matrix/client/unstable": client_resource,
-                        "/_matrix/client/v2_alpha": client_resource,
-                        "/_matrix/client/versions": client_resource,
-                    })
-
-                if name == "federation":
-                    resources.update({
-                        FEDERATION_PREFIX: TransportLayerServer(self),
-                    })
-
-                if name in ["static", "client"]:
-                    resources.update({
-                        STATIC_PREFIX: File(
-                            os.path.join(os.path.dirname(synapse.__file__), "static")
-                        ),
-                    })
-
-                if name in ["media", "federation", "client"]:
-                    media_repo = MediaRepositoryResource(self)
-                    resources.update({
-                        MEDIA_PREFIX: media_repo,
-                        LEGACY_MEDIA_PREFIX: media_repo,
-                        CONTENT_REPO_PREFIX: ContentRepoResource(
-                            self, self.config.uploads_path
-                        ),
-                    })
-
-                if name in ["keys", "federation"]:
-                    resources.update({
-                        SERVER_KEY_PREFIX: LocalKey(self),
-                        SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
-                    })
-
-                if name == "webclient":
-                    resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
-
-                if name == "metrics" and self.get_config().enable_metrics:
-                    resources[METRICS_PREFIX] = MetricsResource(self)
-
-                if name == "replication":
-                    resources[REPLICATION_PREFIX] = ReplicationResource(self)
+                resources.update(self._configure_named_resource(
+                    name, res.get("compress", False),
+                ))
+
+        additional_resources = listener_config.get("additional_resources", {})
+        logger.debug("Configuring additional resources: %r",
+                     additional_resources)
+        module_api = ModuleApi(self, self.get_auth_handler())
+        for path, resmodule in additional_resources.items():
+            handler_cls, config = load_module(resmodule)
+            handler = handler_cls(config, module_api)
+            resources[path] = AdditionalResource(self, handler.handle_request)
 
         if WEB_CLIENT_PREFIX in resources:
             root_resource = RootRedirect(WEB_CLIENT_PREFIX)
         else:
-            root_resource = Resource()
+            root_resource = NoResource()
 
         root_resource = create_resource_tree(resources, root_resource)
 
         if tls:
-            for address in bind_addresses:
-                reactor.listenSSL(
-                    port,
-                    SynapseSite(
-                        "synapse.access.https.%s" % (site_tag,),
-                        site_tag,
-                        listener_config,
-                        root_resource,
-                    ),
-                    self.tls_server_context_factory,
-                    interface=address
-                )
+            listen_ssl(
+                bind_addresses,
+                port,
+                SynapseSite(
+                    "synapse.access.https.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
+                ),
+                self.tls_server_context_factory,
+            )
+
         else:
-            for address in bind_addresses:
-                reactor.listenTCP(
-                    port,
-                    SynapseSite(
-                        "synapse.access.http.%s" % (site_tag,),
-                        site_tag,
-                        listener_config,
-                        root_resource,
-                    ),
-                    interface=address
+            listen_tcp(
+                bind_addresses,
+                port,
+                SynapseSite(
+                    "synapse.access.http.%s" % (site_tag,),
+                    site_tag,
+                    listener_config,
+                    root_resource,
                 )
+            )
         logger.info("Synapse now listening on port %d", port)
 
+    def _configure_named_resource(self, name, compress=False):
+        """Build a resource map for a named resource
+
+        Args:
+            name (str): named resource: one of "client", "federation", etc
+            compress (bool): whether to enable gzip compression for this
+                resource
+
+        Returns:
+            dict[str, Resource]: map from path to HTTP resource
+        """
+        resources = {}
+        if name == "client":
+            client_resource = ClientRestResource(self)
+            if compress:
+                client_resource = gz_wrap(client_resource)
+
+            resources.update({
+                "/_matrix/client/api/v1": client_resource,
+                "/_matrix/client/r0": client_resource,
+                "/_matrix/client/unstable": client_resource,
+                "/_matrix/client/v2_alpha": client_resource,
+                "/_matrix/client/versions": client_resource,
+            })
+
+        if name == "federation":
+            resources.update({
+                FEDERATION_PREFIX: TransportLayerServer(self),
+            })
+
+        if name in ["static", "client"]:
+            resources.update({
+                STATIC_PREFIX: File(
+                    os.path.join(os.path.dirname(synapse.__file__), "static")
+                ),
+            })
+
+        if name in ["media", "federation", "client"]:
+            if self.get_config().enable_media_repo:
+                media_repo = self.get_media_repository_resource()
+                resources.update({
+                    MEDIA_PREFIX: media_repo,
+                    LEGACY_MEDIA_PREFIX: media_repo,
+                    CONTENT_REPO_PREFIX: ContentRepoResource(
+                        self, self.config.uploads_path
+                    ),
+                })
+            elif name == "media":
+                raise ConfigError(
+                    "'media' resource conflicts with enable_media_repo=False",
+                )
+
+        if name in ["keys", "federation"]:
+            resources.update({
+                SERVER_KEY_PREFIX: LocalKey(self),
+                SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
+            })
+
+        if name == "webclient":
+            resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
+
+        if name == "metrics" and self.get_config().enable_metrics:
+            resources[METRICS_PREFIX] = MetricsResource(self)
+
+        if name == "replication":
+            resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
+        return resources
+
     def start_listening(self):
         config = self.get_config()
 
@@ -208,17 +233,24 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
+                listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
+                    )
+                )
+            elif listener["type"] == "replication":
                 bind_addresses = listener["bind_addresses"]
-
                 for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                    factory = ReplicationStreamProtocolFactory(self)
+                    server_listener = reactor.listenTCP(
+                        listener["port"], factory, interface=address
+                    )
+                    reactor.addSystemEventTrigger(
+                        "before", "shutdown", server_listener.stopListening,
                     )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -239,29 +271,6 @@ class SynapseHomeServer(HomeServer):
         except IncorrectDatabaseSetup as e:
             quit_with_error(e.message)
 
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
-
-def quit_with_error(error_string):
-    message_lines = error_string.split("\n")
-    line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
-    sys.stderr.write("*" * line_length + '\n')
-    for line in message_lines:
-        sys.stderr.write(" %s\n" % (line.rstrip(),))
-    sys.stderr.write("*" * line_length + '\n')
-    sys.exit(1)
-
 
 def setup(config_options):
     """
@@ -286,7 +295,7 @@ def setup(config_options):
         # generating config files and shouldn't try to continue.
         sys.exit(0)
 
-    config.setup_logging()
+    synapse.config.logger.setup_logging(config, use_worker_options=False)
 
     # check any extra requirements we have now we have a config
     check_requirements(config)
@@ -340,7 +349,7 @@ def setup(config_options):
         hs.get_state_handler().start_caching()
         hs.get_datastore().start_profiling()
         hs.get_datastore().start_doing_background_updates()
-        hs.get_replication_layer().start_get_pdu_cache()
+        hs.get_federation_client().start_get_pdu_cache()
 
         register_memory_metrics(hs)
 
@@ -389,10 +398,15 @@ def run(hs):
         ThreadPool._worker = profile(ThreadPool._worker)
         reactor.run = profile(reactor.run)
 
-    start_time = hs.get_clock().time()
+    clock = hs.get_clock()
+    start_time = clock.time()
 
     stats = {}
 
+    # Contains the list of processes we will be monitoring
+    # currently either 0 or 1
+    stats_process = []
+
     @defer.inlineCallbacks
     def phone_stats_home():
         logger.info("Gathering stats for reporting")
@@ -401,41 +415,36 @@ def run(hs):
         if uptime < 0:
             uptime = 0
 
-        # If the stats directory is empty then this is the first time we've
-        # reported stats.
-        first_time = not stats
-
         stats["homeserver"] = hs.config.server_name
         stats["timestamp"] = now
         stats["uptime_seconds"] = uptime
         stats["total_users"] = yield hs.get_datastore().count_all_users()
 
+        total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+        stats["total_nonbridged_users"] = total_nonbridged_users
+
         room_count = yield hs.get_datastore().get_room_count()
         stats["total_room_count"] = room_count
 
         stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
-        daily_messages = yield hs.get_datastore().count_daily_messages()
-        if daily_messages is not None:
-            stats["daily_messages"] = daily_messages
-        else:
-            stats.pop("daily_messages", None)
-
-        if first_time:
-            # Add callbacks to report the synapse stats as metrics whenever
-            # prometheus requests them, typically every 30s.
-            # As some of the stats are expensive to calculate we only update
-            # them when synapse phones home to matrix.org every 24 hours.
-            metrics = get_metrics_for("synapse.usage")
-            metrics.add_callback("timestamp", lambda: stats["timestamp"])
-            metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
-            metrics.add_callback("total_users", lambda: stats["total_users"])
-            metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
-            metrics.add_callback(
-                "daily_active_users", lambda: stats["daily_active_users"]
-            )
-            metrics.add_callback(
-                "daily_messages", lambda: stats.get("daily_messages", 0)
-            )
+        stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+        stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+
+        r30_results = yield hs.get_datastore().count_r30_users()
+        for name, count in r30_results.iteritems():
+            stats["r30_users_" + name] = count
+
+        daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+        stats["daily_sent_messages"] = daily_sent_messages
+        stats["cache_factor"] = CACHE_SIZE_FACTOR
+        stats["event_cache_size"] = hs.config.event_cache_size
+
+        if len(stats_process) > 0:
+            stats["memory_rss"] = 0
+            stats["cpu_average"] = 0
+            for process in stats_process:
+                stats["memory_rss"] += process.memory_info().rss
+                stats["cpu_average"] += int(process.cpu_percent(interval=None))
 
         logger.info("Reporting stats to matrix.org: %s" % (stats,))
         try:
@@ -446,37 +455,48 @@ def run(hs):
         except Exception as e:
             logger.warn("Error reporting stats: %s", e)
 
-    if hs.config.report_stats:
-        phone_home_task = task.LoopingCall(phone_stats_home)
-        logger.info("Scheduling stats reporting for 24 hour intervals")
-        phone_home_task.start(60 * 60 * 24, now=False)
-
-    def in_thread():
-        # Uncomment to enable tracing of log context changes.
-        # sys.settrace(logcontext_tracer)
-        with LoggingContext("run"):
-            change_resource_limit(hs.config.soft_file_limit)
-            if hs.config.gc_thresholds:
-                gc.set_threshold(*hs.config.gc_thresholds)
-            reactor.run()
-
-    if hs.config.daemonize:
-
-        if hs.config.print_pidfile:
-            print (hs.config.pid_file)
-
-        daemon = Daemonize(
-            app="synapse-homeserver",
-            pid=hs.config.pid_file,
-            action=lambda: in_thread(),
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
+    def performance_stats_init():
+        try:
+            import psutil
+            process = psutil.Process()
+            # Ensure we can fetch both, and make the initial request for cpu_percent
+            # so the next request will use this as the initial point.
+            process.memory_info().rss
+            process.cpu_percent(interval=None)
+            logger.info("report_stats can use psutil")
+            stats_process.append(process)
+        except (ImportError, AttributeError):
+            logger.warn(
+                "report_stats enabled but psutil is not installed or incorrect version."
+                " Disabling reporting of memory/cpu stats."
+                " Ensuring psutil is available will help matrix.org track performance"
+                " changes across releases."
+            )
 
-        daemon.start()
-    else:
-        in_thread()
+    if hs.config.report_stats:
+        logger.info("Scheduling stats reporting for 3 hour intervals")
+        clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
+
+        # We need to defer this init for the cases that we daemonize
+        # otherwise the process ID we get is that of the non-daemon process
+        clock.call_later(0, performance_stats_init)
+
+        # We wait 5 minutes to send the first set of stats as the server can
+        # be quite busy the first few minutes
+        clock.call_later(5 * 60, phone_stats_home)
+
+    if hs.config.daemonize and hs.config.print_pidfile:
+        print (hs.config.pid_file)
+
+    _base.start_reactor(
+        "synapse-homeserver",
+        hs.config.soft_file_limit,
+        hs.config.gc_thresholds,
+        hs.config.pid_file,
+        hs.config.daemonize,
+        hs.config.cpu_affinity,
+        logger,
+    )
 
 
 def main():
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index ef17b158a5..fc8282bbc1 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -13,45 +13,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
+from synapse import events
+from synapse.api.urls import (
+    CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
+)
+from synapse.app import _base
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
 from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.rest.media.v1.media_repository import MediaRepositoryResource
 from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
 from synapse.storage.media_repository import MediaRepositoryStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-from synapse.api.urls import (
-    CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
-)
-from synapse.crypto import context_factory
-
-from synapse import events
-
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import reactor
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.media_repository")
 
@@ -59,27 +51,15 @@ logger = logging.getLogger("synapse.app.media_repository")
 class MediaRepositorySlavedStore(
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
+    SlavedClientIpStore,
+    TransactionStore,
     BaseSlavedStore,
     MediaRepositoryStore,
-    ClientIpStore,
 ):
     pass
 
 
 class MediaRepositoryServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@@ -95,7 +75,7 @@ class MediaRepositoryServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
                 elif name == "media":
-                    media_repo = MediaRepositoryResource(self)
+                    media_repo = self.get_media_repository_resource()
                     resources.update({
                         MEDIA_PREFIX: media_repo,
                         LEGACY_MEDIA_PREFIX: media_repo,
@@ -104,19 +84,18 @@ class MediaRepositoryServer(HomeServer):
                         ),
                     })
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse media repository now listening on port %d", port)
 
@@ -125,36 +104,22 @@ class MediaRepositoryServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
+        self.get_tcp_replication().start_replication(self)
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+    def build_tcp_replication(self):
+        return ReplicationClientHandler(self.get_datastore())
 
 
 def start(config_options):
@@ -168,7 +133,14 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.media_repository"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    if config.enable_media_repo:
+        _base.quit_with_error(
+            "enable_media_repo must be disabled in the main synapse process\n"
+            "before the media repo can be run in a separate worker.\n"
+            "Please add ``enable_media_repo: false`` to the main config\n"
+        )
+
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -186,36 +158,15 @@ def start(config_options):
     )
 
     ss.setup()
-    ss.get_handlers()
     ss.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
         ss.get_state_handler().start_caching()
         ss.get_datastore().start_profiling()
-        ss.replicate()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-media-repository",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-media-repository", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 073f2c2489..26930d1b3b 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,39 +13,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import sys
 
 import synapse
-
-from synapse.server import HomeServer
+from synapse import events
+from synapse.app import _base
 from synapse.config._base import ConfigError
-from synapse.config.logger import setup_logging
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
 from synapse.http.site import SynapseSite
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.storage.roommember import RoomMemberStore
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.pushers import SlavedPusherStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.storage.engines import create_engine
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
 from synapse.storage import DataStore
-from synapse.util.async import sleep
+from synapse.storage.engines import create_engine
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.versionstring import get_version_string
-
-from synapse import events
-
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import gc
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
 
 logger = logging.getLogger("synapse.app.pusher")
 
@@ -82,42 +74,15 @@ class PusherSlaveStore(
         DataStore.get_profile_displayname.__func__
     )
 
-    who_forgot_in_room = (
-        RoomMemberStore.__dict__["who_forgot_in_room"]
-    )
-
 
 class PusherServer(HomeServer):
-
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = PusherSlaveStore(self.get_db_conn(), self)
         logger.info("Finished setting up.")
 
     def remove_pusher(self, app_id, push_key, user_id):
-        http_client = self.get_simple_http_client()
-        replication_url = self.config.worker_replication_url
-        url = replication_url + "/remove_pushers"
-        return http_client.post_json_get_json(url, {
-            "remove": [{
-                "app_id": app_id,
-                "push_key": push_key,
-                "user_id": user_id,
-            }]
-        })
+        self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
@@ -129,19 +94,18 @@ class PusherServer(HomeServer):
                 if name == "metrics":
                     resources[METRICS_PREFIX] = MetricsResource(self)
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse pusher now listening on port %d", port)
 
@@ -150,88 +114,67 @@ class PusherServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return PusherReplicationHandler(self)
+
+
+class PusherReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(PusherReplicationHandler, self).__init__(hs.get_datastore())
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
+        run_in_background(self.poke_pushers, stream_name, token, rows)
+
     @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        pusher_pool = self.get_pusherpool()
-
-        def stop_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            pushers_for_user = pusher_pool.pushers.get(user_id, {})
-            pusher = pushers_for_user.pop(key, None)
-            if pusher is None:
-                return
-            logger.info("Stopping pusher %r / %r", user_id, key)
-            pusher.on_stop()
-
-        def start_pusher(user_id, app_id, pushkey):
-            key = "%s:%s" % (app_id, pushkey)
-            logger.info("Starting pusher %r / %r", user_id, key)
-            return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
-
-        @defer.inlineCallbacks
-        def poke_pushers(results):
-            pushers_rows = set(
-                map(tuple, results.get("pushers", {}).get("rows", []))
-            )
-            deleted_pushers_rows = set(
-                map(tuple, results.get("deleted_pushers", {}).get("rows", []))
-            )
-            for row in sorted(pushers_rows | deleted_pushers_rows):
-                if row in deleted_pushers_rows:
-                    user_id, app_id, pushkey = row[1:4]
-                    stop_pusher(user_id, app_id, pushkey)
-                elif row in pushers_rows:
-                    user_id = row[1]
-                    app_id = row[5]
-                    pushkey = row[8]
-                    yield start_pusher(user_id, app_id, pushkey)
-
-            stream = results.get("events")
-            if stream and stream["rows"]:
-                min_stream_id = stream["rows"][0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_notifications)(
-                    min_stream_id, max_stream_id
+    def poke_pushers(self, stream_name, token, rows):
+        try:
+            if stream_name == "pushers":
+                for row in rows:
+                    if row.deleted:
+                        yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+                    else:
+                        yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+            elif stream_name == "events":
+                yield self.pusher_pool.on_new_notifications(
+                    token, token,
                 )
-
-            stream = results.get("receipts")
-            if stream and stream["rows"]:
-                rows = stream["rows"]
-                affected_room_ids = set(row[1] for row in rows)
-                min_stream_id = rows[0][0]
-                max_stream_id = stream["position"]
-                preserve_fn(pusher_pool.on_new_receipts)(
-                    min_stream_id, max_stream_id, affected_room_ids
+            elif stream_name == "receipts":
+                yield self.pusher_pool.on_new_receipts(
+                    token, token, set(row.room_id for row in rows)
                 )
+        except Exception:
+            logger.exception("Error poking pushers")
+
+    def stop_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+        pusher = pushers_for_user.pop(key, None)
+        if pusher is None:
+            return
+        logger.info("Stopping pusher %r / %r", user_id, key)
+        pusher.on_stop()
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                poke_pushers(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(30)
+    def start_pusher(self, user_id, app_id, pushkey):
+        key = "%s:%s" % (app_id, pushkey)
+        logger.info("Starting pusher %r / %r", user_id, key)
+        return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
 
 
 def start(config_options):
@@ -245,7 +188,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.pusher"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -274,34 +217,14 @@ def start(config_options):
     ps.setup()
     ps.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
-        ps.replicate()
         ps.get_pusherpool().start()
         ps.get_datastore().start_profiling()
         ps.get_state_handler().start_caching()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-pusher",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-pusher", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 3f29595256..7152b1deb4 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -13,105 +13,87 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
+import logging
+import sys
 
 import synapse
-
-from synapse.api.constants import EventTypes, PresenceState
+from synapse.api.constants import EventTypes
+from synapse.app import _base
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
-from synapse.events import FrozenEvent
-from synapse.handlers.presence import PresenceHandler
-from synapse.http.site import SynapseSite
+from synapse.handlers.presence import PresenceHandler, get_interested_parties
 from synapse.http.server import JsonResource
-from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
-from synapse.rest.client.v2_alpha import sync
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
+from synapse.rest.client.v2_alpha import sync
 from synapse.server import HomeServer
-from synapse.storage.client_ips import ClientIpStore
 from synapse.storage.engines import create_engine
-from synapse.storage.presence import PresenceStore, UserPresenceState
+from synapse.storage.presence import UserPresenceState
 from synapse.storage.roommember import RoomMemberStore
-from synapse.util.async import sleep
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
 from synapse.util.manhole import manhole
-from synapse.util.rlimit import change_resource_limit
 from synapse.util.stringutils import random_string
 from synapse.util.versionstring import get_version_string
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
 
-from twisted.internet import reactor, defer
-from twisted.web.resource import Resource
-
-from daemonize import Daemonize
-
-import sys
-import logging
-import contextlib
-import gc
-import ujson as json
+from six import iteritems
 
 logger = logging.getLogger("synapse.app.synchrotron")
 
 
 class SynchrotronSlavedStore(
-    SlavedPushRuleStore,
-    SlavedEventStore,
     SlavedReceiptsStore,
     SlavedAccountDataStore,
     SlavedApplicationServiceStore,
     SlavedRegistrationStore,
     SlavedFilteringStore,
     SlavedPresenceStore,
+    SlavedGroupServerStore,
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
+    SlavedPushRuleStore,
+    SlavedEventStore,
+    SlavedClientIpStore,
     RoomStore,
     BaseSlavedStore,
-    ClientIpStore,  # After BaseSlavedStore because the constructor is different
 ):
-    who_forgot_in_room = (
-        RoomMemberStore.__dict__["who_forgot_in_room"]
-    )
-
     did_forget = (
         RoomMemberStore.__dict__["did_forget"]
     )
 
-    # XXX: This is a bit broken because we don't persist the accepted list in a
-    # way that can be replicated. This means that we don't have a way to
-    # invalidate the cache correctly.
-    get_presence_list_accepted = PresenceStore.__dict__[
-        "get_presence_list_accepted"
-    ]
-    get_presence_list_observers_accepted = PresenceStore.__dict__[
-        "get_presence_list_observers_accepted"
-    ]
-
 
 UPDATE_SYNCING_USERS_MS = 10 * 1000
 
 
 class SynchrotronPresence(object):
     def __init__(self, hs):
+        self.hs = hs
         self.is_mine_id = hs.is_mine_id
         self.http_client = hs.get_simple_http_client()
         self.store = hs.get_datastore()
         self.user_to_num_current_syncs = {}
-        self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
 
@@ -121,17 +103,52 @@ class SynchrotronPresence(object):
             for state in active_presence
         }
 
-        self.process_id = random_string(16)
-        logger.info("Presence process_id is %r", self.process_id)
+        # user_id -> last_sync_ms. Lists the users that have stopped syncing
+        # but we haven't notified the master of that yet
+        self.users_going_offline = {}
 
-        self._sending_sync = False
-        self._need_to_send_sync = False
-        self.clock.looping_call(
-            self._send_syncing_users_regularly,
-            UPDATE_SYNCING_USERS_MS,
+        self._send_stop_syncing_loop = self.clock.looping_call(
+            self.send_stop_syncing, 10 * 1000
         )
 
-        reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
+        self.process_id = random_string(16)
+        logger.info("Presence process_id is %r", self.process_id)
+
+    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+        self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
+
+    def mark_as_coming_online(self, user_id):
+        """A user has started syncing. Send a UserSync to the master, unless they
+        had recently stopped syncing.
+
+        Args:
+            user_id (str)
+        """
+        going_offline = self.users_going_offline.pop(user_id, None)
+        if not going_offline:
+            # Safe to skip because we haven't yet told the master they were offline
+            self.send_user_sync(user_id, True, self.clock.time_msec())
+
+    def mark_as_going_offline(self, user_id):
+        """A user has stopped syncing. We wait before notifying the master as
+        its likely they'll come back soon. This allows us to avoid sending
+        a stopped syncing immediately followed by a started syncing notification
+        to the master
+
+        Args:
+            user_id (str)
+        """
+        self.users_going_offline[user_id] = self.clock.time_msec()
+
+    def send_stop_syncing(self):
+        """Check if there are any users who have stopped syncing a while ago
+        and haven't come back yet. If there are poke the master about them.
+        """
+        now = self.clock.time_msec()
+        for user_id, last_sync_ms in self.users_going_offline.items():
+            if now - last_sync_ms > 10 * 1000:
+                self.users_going_offline.pop(user_id, None)
+                self.send_user_sync(user_id, False, last_sync_ms)
 
     def set_state(self, user, state, ignore_status_msg=False):
         # TODO Hows this supposed to work?
@@ -139,18 +156,16 @@ class SynchrotronPresence(object):
 
     get_states = PresenceHandler.get_states.__func__
     get_state = PresenceHandler.get_state.__func__
-    _get_interested_parties = PresenceHandler._get_interested_parties.__func__
     current_state_for_users = PresenceHandler.current_state_for_users.__func__
 
-    @defer.inlineCallbacks
     def user_syncing(self, user_id, affect_presence):
         if affect_presence:
             curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
             self.user_to_num_current_syncs[user_id] = curr_sync + 1
-            prev_states = yield self.current_state_for_users([user_id])
-            if prev_states[user_id].state == PresenceState.OFFLINE:
-                # TODO: Don't block the sync request on this HTTP hit.
-                yield self._send_syncing_users_now()
+
+            # If we went from no in flight sync to some, notify replication
+            if self.user_to_num_current_syncs[user_id] == 1:
+                self.mark_as_coming_online(user_id)
 
         def _end():
             # We check that the user_id is in user_to_num_current_syncs because
@@ -159,6 +174,10 @@ class SynchrotronPresence(object):
             if affect_presence and user_id in self.user_to_num_current_syncs:
                 self.user_to_num_current_syncs[user_id] -= 1
 
+                # If we went from one in flight sync to non, notify replication
+                if self.user_to_num_current_syncs[user_id] == 0:
+                    self.mark_as_going_offline(user_id)
+
         @contextlib.contextmanager
         def _user_syncing():
             try:
@@ -166,56 +185,12 @@ class SynchrotronPresence(object):
             finally:
                 _end()
 
-        defer.returnValue(_user_syncing())
-
-    @defer.inlineCallbacks
-    def _on_shutdown(self):
-        # When the synchrotron is shutdown tell the master to clear the in
-        # progress syncs for this process
-        self.user_to_num_current_syncs.clear()
-        yield self._send_syncing_users_now()
-
-    def _send_syncing_users_regularly(self):
-        # Only send an update if we aren't in the middle of sending one.
-        if not self._sending_sync:
-            preserve_fn(self._send_syncing_users_now)()
-
-    @defer.inlineCallbacks
-    def _send_syncing_users_now(self):
-        if self._sending_sync:
-            # We don't want to race with sending another update.
-            # Instead we wait for that update to finish and send another
-            # update afterwards.
-            self._need_to_send_sync = True
-            return
-
-        # Flag that we are sending an update.
-        self._sending_sync = True
-
-        yield self.http_client.post_json_get_json(self.syncing_users_url, {
-            "process_id": self.process_id,
-            "syncing_users": [
-                user_id for user_id, count in self.user_to_num_current_syncs.items()
-                if count > 0
-            ],
-        })
-
-        # Unset the flag as we are no longer sending an update.
-        self._sending_sync = False
-        if self._need_to_send_sync:
-            # If something happened while we were sending the update then
-            # we might need to send another update.
-            # TODO: Check if the update that was sent matches the current state
-            # as we only need to send an update if they are different.
-            self._need_to_send_sync = False
-            yield self._send_syncing_users_now()
+        return defer.succeed(_user_syncing())
 
     @defer.inlineCallbacks
     def notify_from_replication(self, states, stream_id):
-        parties = yield self._get_interested_parties(
-            states, calculate_remote_hosts=False
-        )
-        room_ids_to_states, users_to_states, _ = parties
+        parties = yield get_interested_parties(self.store, states)
+        room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
             "presence_key", stream_id, rooms=room_ids_to_states.keys(),
@@ -223,26 +198,24 @@ class SynchrotronPresence(object):
         )
 
     @defer.inlineCallbacks
-    def process_replication(self, result):
-        stream = result.get("presence", {"rows": []})
-        states = []
-        for row in stream["rows"]:
-            (
-                position, user_id, state, last_active_ts,
-                last_federation_update_ts, last_user_sync_ts, status_msg,
-                currently_active
-            ) = row
-            state = UserPresenceState(
-                user_id, state, last_active_ts,
-                last_federation_update_ts, last_user_sync_ts, status_msg,
-                currently_active
-            )
-            self.user_to_current_state[user_id] = state
-            states.append(state)
+    def process_replication_rows(self, token, rows):
+        states = [UserPresenceState(
+            row.user_id, row.state, row.last_active_ts,
+            row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
+            row.currently_active
+        ) for row in rows]
+
+        for state in states:
+            self.user_to_current_state[row.user_id] = state
+
+        stream_id = token
+        yield self.notify_from_replication(states, stream_id)
 
-        if states and "position" in stream:
-            stream_id = int(stream["position"])
-            yield self.notify_from_replication(states, stream_id)
+    def get_currently_syncing_users(self):
+        return [
+            user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+            if count > 0
+        ]
 
 
 class SynchrotronTyping(object):
@@ -257,16 +230,12 @@ class SynchrotronTyping(object):
         # value which we *must* use for the next replication request.
         return {"typing": self._latest_room_serial}
 
-    def process_replication(self, result):
-        stream = result.get("typing")
-        if stream:
-            self._latest_room_serial = int(stream["position"])
+    def process_replication_rows(self, token, rows):
+        self._latest_room_serial = token
 
-            for row in stream["rows"]:
-                position, room_id, typing_json = row
-                typing = json.loads(typing_json)
-                self._room_serials[room_id] = position
-                self._room_typing[room_id] = typing
+        for row in rows:
+            self._room_serials[row.room_id] = token
+            self._room_typing[row.room_id] = row.user_ids
 
 
 class SynchrotronApplicationService(object):
@@ -275,19 +244,6 @@ class SynchrotronApplicationService(object):
 
 
 class SynchrotronServer(HomeServer):
-    def get_db_conn(self, run_new_connection=True):
-        # Any param beginning with cp_ is a parameter for adbapi, and should
-        # not be passed to the database engine.
-        db_params = {
-            k: v for k, v in self.db_config.get("args", {}).items()
-            if not k.startswith("cp_")
-        }
-        db_conn = self.database_engine.module.connect(**db_params)
-
-        if run_new_connection:
-            self.database_engine.on_new_connection(db_conn)
-        return db_conn
-
     def setup(self):
         logger.info("Setting up.")
         self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@@ -315,19 +271,18 @@ class SynchrotronServer(HomeServer):
                         "/_matrix/client/api/v1": resource,
                     })
 
-        root_resource = create_resource_tree(resources, Resource())
-
-        for address in bind_addresses:
-            reactor.listenTCP(
-                port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                ),
-                interface=address
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
             )
+        )
 
         logger.info("Synapse synchrotron now listening on port %d", port)
 
@@ -336,135 +291,106 @@ class SynchrotronServer(HomeServer):
             if listener["type"] == "http":
                 self._listen_http(listener)
             elif listener["type"] == "manhole":
-                bind_addresses = listener["bind_addresses"]
-
-                for address in bind_addresses:
-                    reactor.listenTCP(
-                        listener["port"],
-                        manhole(
-                            username="matrix",
-                            password="rabbithole",
-                            globals={"hs": self},
-                        ),
-                        interface=address
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
                     )
+                )
             else:
                 logger.warn("Unrecognized listener type: %s", listener["type"])
 
-    @defer.inlineCallbacks
-    def replicate(self):
-        http_client = self.get_simple_http_client()
-        store = self.get_datastore()
-        replication_url = self.config.worker_replication_url
-        notifier = self.get_notifier()
-        presence_handler = self.get_presence_handler()
-        typing_handler = self.get_typing_handler()
-
-        def notify_from_stream(
-            result, stream_name, stream_key, room=None, user=None
-        ):
-            stream = result.get(stream_name)
-            if stream:
-                position_index = stream["field_names"].index("position")
-                if room:
-                    room_index = stream["field_names"].index(room)
-                if user:
-                    user_index = stream["field_names"].index(user)
-
-                users = ()
-                rooms = ()
-                for row in stream["rows"]:
-                    position = row[position_index]
-
-                    if user:
-                        users = (row[user_index],)
-
-                    if room:
-                        rooms = (row[room_index],)
-
-                    notifier.on_new_event(
-                        stream_key, position, users=users, rooms=rooms
-                    )
+        self.get_tcp_replication().start_replication(self)
 
-        @defer.inlineCallbacks
-        def notify_device_list_update(result):
-            stream = result.get("device_lists")
-            if not stream:
-                return
+    def build_tcp_replication(self):
+        return SyncReplicationHandler(self)
 
-            position_index = stream["field_names"].index("position")
-            user_index = stream["field_names"].index("user_id")
+    def build_presence_handler(self):
+        return SynchrotronPresence(self)
 
-            for row in stream["rows"]:
-                position = row[position_index]
-                user_id = row[user_index]
+    def build_typing_handler(self):
+        return SynchrotronTyping(self)
 
-                rooms = yield store.get_rooms_for_user(user_id)
-                room_ids = [r.room_id for r in rooms]
 
-                notifier.on_new_event(
-                    "device_list_key", position, rooms=room_ids,
-                )
+class SyncReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(SyncReplicationHandler, self).__init__(hs.get_datastore())
 
-        @defer.inlineCallbacks
-        def notify(result):
-            stream = result.get("events")
-            if stream:
-                max_position = stream["position"]
-                for row in stream["rows"]:
-                    position = row[0]
-                    internal = json.loads(row[1])
-                    event_json = json.loads(row[2])
-                    event = FrozenEvent(event_json, internal_metadata_dict=internal)
-                    extra_users = ()
-                    if event.type == EventTypes.Member:
-                        extra_users = (event.state_key,)
-                    notifier.on_new_room_event(
-                        event, position, max_position, extra_users
-                    )
+        self.store = hs.get_datastore()
+        self.typing_handler = hs.get_typing_handler()
+        # NB this is a SynchrotronPresence, not a normal PresenceHandler
+        self.presence_handler = hs.get_presence_handler()
+        self.notifier = hs.get_notifier()
 
-            notify_from_stream(
-                result, "push_rules", "push_rules_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "user_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "room_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "tag_account_data", "account_data_key", user="user_id"
-            )
-            notify_from_stream(
-                result, "receipts", "receipt_key", room="room_id"
-            )
-            notify_from_stream(
-                result, "typing", "typing_key", room="room_id"
-            )
-            notify_from_stream(
-                result, "to_device", "to_device_key", user="user_id"
-            )
-            yield notify_device_list_update(result)
+    def on_rdata(self, stream_name, token, rows):
+        super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
+        run_in_background(self.process_and_notify, stream_name, token, rows)
 
-        while True:
-            try:
-                args = store.stream_positions()
-                args.update(typing_handler.stream_positions())
-                args["timeout"] = 30000
-                result = yield http_client.get_json(replication_url, args=args)
-                yield store.process_replication(result)
-                typing_handler.process_replication(result)
-                yield presence_handler.process_replication(result)
-                yield notify(result)
-            except:
-                logger.exception("Error replicating from %r", replication_url)
-                yield sleep(5)
+    def get_streams_to_replicate(self):
+        args = super(SyncReplicationHandler, self).get_streams_to_replicate()
+        args.update(self.typing_handler.stream_positions())
+        return args
 
-    def build_presence_handler(self):
-        return SynchrotronPresence(self)
+    def get_currently_syncing_users(self):
+        return self.presence_handler.get_currently_syncing_users()
 
-    def build_typing_handler(self):
-        return SynchrotronTyping(self)
+    @defer.inlineCallbacks
+    def process_and_notify(self, stream_name, token, rows):
+        try:
+            if stream_name == "events":
+                # We shouldn't get multiple rows per token for events stream, so
+                # we don't need to optimise this for multiple rows.
+                for row in rows:
+                    event = yield self.store.get_event(row.event_id)
+                    extra_users = ()
+                    if event.type == EventTypes.Member:
+                        extra_users = (event.state_key,)
+                    max_token = self.store.get_room_max_stream_ordering()
+                    self.notifier.on_new_room_event(
+                        event, token, max_token, extra_users
+                    )
+            elif stream_name == "push_rules":
+                self.notifier.on_new_event(
+                    "push_rules_key", token, users=[row.user_id for row in rows],
+                )
+            elif stream_name in ("account_data", "tag_account_data",):
+                self.notifier.on_new_event(
+                    "account_data_key", token, users=[row.user_id for row in rows],
+                )
+            elif stream_name == "receipts":
+                self.notifier.on_new_event(
+                    "receipt_key", token, rooms=[row.room_id for row in rows],
+                )
+            elif stream_name == "typing":
+                self.typing_handler.process_replication_rows(token, rows)
+                self.notifier.on_new_event(
+                    "typing_key", token, rooms=[row.room_id for row in rows],
+                )
+            elif stream_name == "to_device":
+                entities = [row.entity for row in rows if row.entity.startswith("@")]
+                if entities:
+                    self.notifier.on_new_event(
+                        "to_device_key", token, users=entities,
+                    )
+            elif stream_name == "device_lists":
+                all_room_ids = set()
+                for row in rows:
+                    room_ids = yield self.store.get_rooms_for_user(row.user_id)
+                    all_room_ids.update(room_ids)
+                self.notifier.on_new_event(
+                    "device_list_key", token, rooms=all_room_ids,
+                )
+            elif stream_name == "presence":
+                yield self.presence_handler.process_replication_rows(token, rows)
+            elif stream_name == "receipts":
+                self.notifier.on_new_event(
+                    "groups_key", token, users=[row.user_id for row in rows],
+                )
+        except Exception:
+            logger.exception("Error processing replication")
 
 
 def start(config_options):
@@ -478,7 +404,7 @@ def start(config_options):
 
     assert config.worker_app == "synapse.app.synchrotron"
 
-    setup_logging(config.worker_log_config, config.worker_log_file)
+    setup_logging(config, use_worker_options=True)
 
     synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -496,33 +422,13 @@ def start(config_options):
     ss.setup()
     ss.start_listening(config.worker_listeners)
 
-    def run():
-        with LoggingContext("run"):
-            logger.info("Running")
-            change_resource_limit(config.soft_file_limit)
-            if config.gc_thresholds:
-                gc.set_threshold(*config.gc_thresholds)
-            reactor.run()
-
     def start():
         ss.get_datastore().start_profiling()
-        ss.replicate()
         ss.get_state_handler().start_caching()
 
     reactor.callWhenRunning(start)
 
-    if config.worker_daemonize:
-        daemon = Daemonize(
-            app="synapse-synchrotron",
-            pid=config.worker_pid_file,
-            action=run,
-            auto_close_fds=False,
-            verbose=True,
-            logger=logger,
-        )
-        daemon.start()
-    else:
-        run()
+    _base.start_worker_reactor("synapse-synchrotron", config)
 
 
 if __name__ == '__main__':
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index c045588866..712dfa870e 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -23,14 +23,27 @@ import signal
 import subprocess
 import sys
 import yaml
+import errno
+import time
 
 SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
+YELLOW = "\x1b[1;33m"
 RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"
 
 
+def pid_running(pid):
+    try:
+        os.kill(pid, 0)
+        return True
+    except OSError as err:
+        if err.errno == errno.EPERM:
+            return True
+        return False
+
+
 def write(message, colour=NORMAL, stream=sys.stdout):
     if colour == NORMAL:
         stream.write(message + "\n")
@@ -38,6 +51,11 @@ def write(message, colour=NORMAL, stream=sys.stdout):
         stream.write(colour + message + NORMAL + "\n")
 
 
+def abort(message, colour=RED, stream=sys.stderr):
+    write(message, colour, stream)
+    sys.exit(1)
+
+
 def start(configfile):
     write("Starting ...")
     args = SYNAPSE
@@ -45,7 +63,8 @@ def start(configfile):
 
     try:
         subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        write("started synapse.app.homeserver(%r)" %
+              (configfile,), colour=GREEN)
     except subprocess.CalledProcessError as e:
         write(
             "error starting (exit code: %d); see above for logs" % e.returncode,
@@ -76,8 +95,16 @@ def start_worker(app, configfile, worker_configfile):
 def stop(pidfile, app):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
-        os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        try:
+            os.kill(pid, signal.SIGTERM)
+            write("stopped %s" % (app,), colour=GREEN)
+        except OSError as err:
+            if err.errno == errno.ESRCH:
+                write("%s not running" % (app,), colour=YELLOW)
+            elif err.errno == errno.EPERM:
+                abort("Cannot stop %s: Operation not permitted" % (app,))
+            else:
+                abort("Cannot stop %s: Unknown error" % (app,))
 
 
 Worker = collections.namedtuple("Worker", [
@@ -98,7 +125,7 @@ def main():
         "configfile",
         nargs="?",
         default="homeserver.yaml",
-        help="the homeserver config file, defaults to homserver.yaml",
+        help="the homeserver config file, defaults to homeserver.yaml",
     )
     parser.add_argument(
         "-w", "--worker",
@@ -157,6 +184,9 @@ def main():
         worker_configfiles.append(worker_configfile)
 
     if options.all_processes:
+        # To start the main synapse with -a you need to add a worker file
+        # with worker_app == "synapse.app.homeserver"
+        start_stop_synapse = False
         worker_configdir = options.all_processes
         if not os.path.isdir(worker_configdir):
             write(
@@ -173,10 +203,29 @@ def main():
         with open(worker_configfile) as stream:
             worker_config = yaml.load(stream)
         worker_app = worker_config["worker_app"]
-        worker_pidfile = worker_config["worker_pid_file"]
-        worker_daemonize = worker_config["worker_daemonize"]
-        assert worker_daemonize  # TODO print something more user friendly
-        worker_cache_factor = worker_config.get("synctl_cache_factor")
+        if worker_app == "synapse.app.homeserver":
+            # We need to special case all of this to pick up options that may
+            # be set in the main config file or in this worker config file.
+            worker_pidfile = (
+                worker_config.get("pid_file")
+                or pidfile
+            )
+            worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
+            daemonize = worker_config.get("daemonize") or config.get("daemonize")
+            assert daemonize, "Main process must have daemonize set to true"
+
+            # The master process doesn't support using worker_* config.
+            for key in worker_config:
+                if key == "worker_app":  # But we allow worker_app
+                    continue
+                assert not key.startswith("worker_"), \
+                    "Main process cannot use worker_* config"
+        else:
+            worker_pidfile = worker_config["worker_pid_file"]
+            worker_daemonize = worker_config["worker_daemonize"]
+            assert worker_daemonize, "In config %r: expected '%s' to be True" % (
+                worker_configfile, "worker_daemonize")
+            worker_cache_factor = worker_config.get("synctl_cache_factor")
         workers.append(Worker(
             worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
         ))
@@ -190,10 +239,26 @@ def main():
         if start_stop_synapse:
             stop(pidfile, "synapse.app.homeserver")
 
-        # TODO: Wait for synapse to actually shutdown before starting it again
+    # Wait for synapse to actually shutdown before starting it again
+    if action == "restart":
+        running_pids = []
+        if start_stop_synapse and os.path.exists(pidfile):
+            running_pids.append(int(open(pidfile).read()))
+        for worker in workers:
+            if os.path.exists(worker.pidfile):
+                running_pids.append(int(open(worker.pidfile).read()))
+        if len(running_pids) > 0:
+            write("Waiting for process to exit before restarting...")
+            for running_pid in running_pids:
+                while pid_running(running_pid):
+                    time.sleep(0.2)
+            write("All processes exited; now restarting...")
 
     if action == "start" or action == "restart":
         if start_stop_synapse:
+            # Check if synapse is already running
+            if os.path.exists(pidfile) and pid_running(int(open(pidfile).read())):
+                abort("synapse.app.homeserver already running")
             start(configfile)
 
         for worker in workers:
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
new file mode 100644
index 0000000000..5ba7e9b416
--- /dev/null
+++ b/synapse/app/user_dir.py
@@ -0,0 +1,231 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import sys
+
+import synapse
+from synapse import events
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.crypto import context_factory
+from synapse.http.server import JsonResource
+from synapse.http.site import SynapseSite
+from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.rest.client.v2_alpha import user_directory
+from synapse.server import HomeServer
+from synapse.storage.engines import create_engine
+from synapse.storage.user_directory import UserDirectoryStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.logcontext import LoggingContext, run_in_background
+from synapse.util.manhole import manhole
+from synapse.util.versionstring import get_version_string
+from twisted.internet import reactor, defer
+from twisted.web.resource import NoResource
+
+logger = logging.getLogger("synapse.app.user_dir")
+
+
+class UserDirectorySlaveStore(
+    SlavedEventStore,
+    SlavedApplicationServiceStore,
+    SlavedRegistrationStore,
+    SlavedClientIpStore,
+    UserDirectoryStore,
+    BaseSlavedStore,
+):
+    def __init__(self, db_conn, hs):
+        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+            db_conn, "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache = StreamChangeCache(
+            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
+        )
+
+        self._current_state_delta_pos = events_max
+
+    def stream_positions(self):
+        result = super(UserDirectorySlaveStore, self).stream_positions()
+        result["current_state_deltas"] = self._current_state_delta_pos
+        return result
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "current_state_deltas":
+            self._current_state_delta_pos = token
+            for row in rows:
+                self._curr_state_delta_stream_cache.entity_has_changed(
+                    row.room_id, token
+                )
+        return super(UserDirectorySlaveStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
+
+
+class UserDirectoryServer(HomeServer):
+    def setup(self):
+        logger.info("Setting up.")
+        self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
+        logger.info("Finished setting up.")
+
+    def _listen_http(self, listener_config):
+        port = listener_config["port"]
+        bind_addresses = listener_config["bind_addresses"]
+        site_tag = listener_config.get("tag", port)
+        resources = {}
+        for res in listener_config["resources"]:
+            for name in res["names"]:
+                if name == "metrics":
+                    resources[METRICS_PREFIX] = MetricsResource(self)
+                elif name == "client":
+                    resource = JsonResource(self, canonical_json=False)
+                    user_directory.register_servlets(self, resource)
+                    resources.update({
+                        "/_matrix/client/r0": resource,
+                        "/_matrix/client/unstable": resource,
+                        "/_matrix/client/v2_alpha": resource,
+                        "/_matrix/client/api/v1": resource,
+                    })
+
+        root_resource = create_resource_tree(resources, NoResource())
+
+        _base.listen_tcp(
+            bind_addresses,
+            port,
+            SynapseSite(
+                "synapse.access.http.%s" % (site_tag,),
+                site_tag,
+                listener_config,
+                root_resource,
+            )
+        )
+
+        logger.info("Synapse user_dir now listening on port %d", port)
+
+    def start_listening(self, listeners):
+        for listener in listeners:
+            if listener["type"] == "http":
+                self._listen_http(listener)
+            elif listener["type"] == "manhole":
+                _base.listen_tcp(
+                    listener["bind_addresses"],
+                    listener["port"],
+                    manhole(
+                        username="matrix",
+                        password="rabbithole",
+                        globals={"hs": self},
+                    )
+                )
+            else:
+                logger.warn("Unrecognized listener type: %s", listener["type"])
+
+        self.get_tcp_replication().start_replication(self)
+
+    def build_tcp_replication(self):
+        return UserDirectoryReplicationHandler(self)
+
+
+class UserDirectoryReplicationHandler(ReplicationClientHandler):
+    def __init__(self, hs):
+        super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
+        self.user_directory = hs.get_user_directory_handler()
+
+    def on_rdata(self, stream_name, token, rows):
+        super(UserDirectoryReplicationHandler, self).on_rdata(
+            stream_name, token, rows
+        )
+        if stream_name == "current_state_deltas":
+            run_in_background(self._notify_directory)
+
+    @defer.inlineCallbacks
+    def _notify_directory(self):
+        try:
+            yield self.user_directory.notify_new_event()
+        except Exception:
+            logger.exception("Error notifiying user directory of state update")
+
+
+def start(config_options):
+    try:
+        config = HomeServerConfig.load_config(
+            "Synapse user directory", config_options
+        )
+    except ConfigError as e:
+        sys.stderr.write("\n" + e.message + "\n")
+        sys.exit(1)
+
+    assert config.worker_app == "synapse.app.user_dir"
+
+    setup_logging(config, use_worker_options=True)
+
+    events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+    database_engine = create_engine(config.database_config)
+
+    if config.update_user_directory:
+        sys.stderr.write(
+            "\nThe update_user_directory must be disabled in the main synapse process"
+            "\nbefore they can be run in a separate worker."
+            "\nPlease add ``update_user_directory: false`` to the main config"
+            "\n"
+        )
+        sys.exit(1)
+
+    # Force the pushers to start since they will be disabled in the main config
+    config.update_user_directory = True
+
+    tls_server_context_factory = context_factory.ServerContextFactory(config)
+
+    ps = UserDirectoryServer(
+        config.server_name,
+        db_config=config.database_config,
+        tls_server_context_factory=tls_server_context_factory,
+        config=config,
+        version_string="Synapse/" + get_version_string(synapse),
+        database_engine=database_engine,
+    )
+
+    ps.setup()
+    ps.start_listening(config.worker_listeners)
+
+    def start():
+        ps.get_datastore().start_profiling()
+        ps.get_state_handler().start_caching()
+
+    reactor.callWhenRunning(start)
+
+    _base.start_worker_reactor("synapse-user-dir", config)
+
+
+if __name__ == '__main__':
+    with LoggingContext("main"):
+        start(sys.argv[1:])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b0106a3597..5fdb579723 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.api.constants import EventTypes
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.types import GroupID, get_domain_from_id
 
 from twisted.internet import defer
 
 import logging
 import re
 
+from six import string_types
+
 logger = logging.getLogger(__name__)
 
 
@@ -80,12 +84,13 @@ class ApplicationService(object):
     # values.
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
-    def __init__(self, token, url=None, namespaces=None, hs_token=None,
+    def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
                  sender=None, id=None, protocols=None, rate_limited=True):
         self.token = token
         self.url = url
         self.hs_token = hs_token
         self.sender = sender
+        self.server_name = hostname
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
 
@@ -124,29 +129,41 @@ class ApplicationService(object):
                     raise ValueError(
                         "Expected bool for 'exclusive' in ns '%s'" % ns
                     )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                group_id = regex_obj.get("group_id")
+                if group_id:
+                    if not isinstance(group_id, str):
+                        raise ValueError(
+                            "Expected string for 'group_id' in ns '%s'" % ns
+                        )
+                    try:
+                        GroupID.from_string(group_id)
+                    except Exception:
+                        raise ValueError(
+                            "Expected valid group ID for 'group_id' in ns '%s'" % ns
+                        )
+
+                    if get_domain_from_id(group_id) != self.server_name:
+                        raise ValueError(
+                            "Expected 'group_id' to be this host in ns '%s'" % ns
+                        )
+
+                regex = regex_obj.get("regex")
+                if isinstance(regex, string_types):
+                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
+                else:
                     raise ValueError(
                         "Expected string for 'regex' in ns '%s'" % ns
                     )
         return namespaces
 
-    def _matches_regex(self, test_string, namespace_key, return_obj=False):
-        if not isinstance(test_string, basestring):
-            logger.error(
-                "Expected a string to test regex against, but got %s",
-                test_string
-            )
-            return False
-
+    def _matches_regex(self, test_string, namespace_key):
         for regex_obj in self.namespaces[namespace_key]:
-            if re.match(regex_obj["regex"], test_string):
-                if return_obj:
-                    return regex_obj
-                return True
-        return False
+            if regex_obj["regex"].match(test_string):
+                return regex_obj
+        return None
 
     def _is_exclusive(self, ns_key, test_string):
-        regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
+        regex_obj = self._matches_regex(test_string, ns_key)
         if regex_obj:
             return regex_obj["exclusive"]
         return False
@@ -166,7 +183,14 @@ class ApplicationService(object):
         if not store:
             defer.returnValue(False)
 
-        member_list = yield store.get_users_in_room(event.room_id)
+        does_match = yield self._matches_user_in_member_list(event.room_id, store)
+        defer.returnValue(does_match)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def _matches_user_in_member_list(self, room_id, store, cache_context):
+        member_list = yield store.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
 
         # check joined member events
         for user_id in member_list:
@@ -219,10 +243,10 @@ class ApplicationService(object):
         )
 
     def is_interested_in_alias(self, alias):
-        return self._matches_regex(alias, ApplicationService.NS_ALIASES)
+        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
 
     def is_interested_in_room(self, room_id):
-        return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
+        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
 
     def is_exclusive_user(self, user_id):
         return (
@@ -239,6 +263,31 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
+    def get_exlusive_user_regexes(self):
+        """Get the list of regexes used to determine if a user is exclusively
+        registered by the AS
+        """
+        return [
+            regex_obj["regex"]
+            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+            if regex_obj["exclusive"]
+        ]
+
+    def get_groups_for_user(self, user_id):
+        """Get the groups that this user is associated with by this AS
+
+        Args:
+            user_id (str): The ID of the user.
+
+        Returns:
+            iterable[str]: an iterable that yields group_id strings.
+        """
+        return (
+            regex_obj["group_id"]
+            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
+            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+        )
+
     def is_rate_limited(self):
         return self.rate_limited
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6893610e71..00efff1464 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -72,7 +72,8 @@ class ApplicationServiceApi(SimpleHttpClient):
         super(ApplicationServiceApi, self).__init__(hs)
         self.clock = hs.get_clock()
 
-        self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
+        self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
+                                                 timeout_ms=HOUR_IN_MS)
 
     @defer.inlineCallbacks
     def query_user(self, service, user_id):
@@ -192,9 +193,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 defer.returnValue(None)
 
         key = (service.id, protocol)
-        return self.protocol_meta_cache.get(key) or (
-            self.protocol_meta_cache.set(key, _get())
-        )
+        return self.protocol_meta_cache.wrap(key, _get)
 
     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 68a9de17b8..6eddbc0828 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -51,7 +51,7 @@ components.
 from twisted.internet import defer
 
 from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import Measure
 
 import logging
@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
     def enqueue(self, service, event):
         # if this service isn't being sent something
         self.queued_events.setdefault(service.id, []).append(event)
-        preserve_fn(self._send_request)(service)
+        run_in_background(self._send_request, service)
 
     @defer.inlineCallbacks
     def _send_request(self, service):
@@ -123,7 +123,7 @@ class _ServiceQueuer(object):
                 with Measure(self.clock, "servicequeuer.send"):
                     try:
                         yield self.txn_ctrl.send(service, events)
-                    except:
+                    except Exception:
                         logger.exception("AS request failed")
         finally:
             self.requests_in_flight.discard(service.id)
@@ -152,10 +152,10 @@ class _TransactionController(object):
                 if sent:
                     yield txn.complete(self.store)
                 else:
-                    preserve_fn(self._start_recoverer)(service)
-        except Exception as e:
-            logger.exception(e)
-            preserve_fn(self._start_recoverer)(service)
+                    run_in_background(self._start_recoverer, service)
+        except Exception:
+            logger.exception("Error creating appservice transaction")
+            run_in_background(self._start_recoverer, service)
 
     @defer.inlineCallbacks
     def on_recovered(self, recoverer):
@@ -176,17 +176,20 @@ class _TransactionController(object):
 
     @defer.inlineCallbacks
     def _start_recoverer(self, service):
-        yield self.store.set_appservice_state(
-            service,
-            ApplicationServiceState.DOWN
-        )
-        logger.info(
-            "Application service falling behind. Starting recoverer. AS ID %s",
-            service.id
-        )
-        recoverer = self.recoverer_fn(service, self.on_recovered)
-        self.add_recoverers([recoverer])
-        recoverer.recover()
+        try:
+            yield self.store.set_appservice_state(
+                service,
+                ApplicationServiceState.DOWN
+            )
+            logger.info(
+                "Application service falling behind. Starting recoverer. AS ID %s",
+                service.id
+            )
+            recoverer = self.recoverer_fn(service, self.on_recovered)
+            self.add_recoverers([recoverer])
+            recoverer.recover()
+        except Exception:
+            logger.exception("Error starting AS recoverer")
 
     @defer.inlineCallbacks
     def _is_service_up(self, service):
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1ab5593c6e..b748ed2b0a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,6 +19,8 @@ import os
 import yaml
 from textwrap import dedent
 
+from six import integer_types
+
 
 class ConfigError(Exception):
     pass
@@ -49,7 +51,7 @@ Missing mandatory `server_name` config option.
 class Config(object):
     @staticmethod
     def parse_size(value):
-        if isinstance(value, int) or isinstance(value, long):
+        if isinstance(value, integer_types):
             return value
         sizes = {"K": 1024, "M": 1024 * 1024}
         size = 1
@@ -61,7 +63,7 @@ class Config(object):
 
     @staticmethod
     def parse_duration(value):
-        if isinstance(value, int) or isinstance(value, long):
+        if isinstance(value, integer_types):
             return value
         second = 1000
         minute = 60 * second
@@ -82,21 +84,37 @@ class Config(object):
         return os.path.abspath(file_path) if file_path else file_path
 
     @classmethod
+    def path_exists(cls, file_path):
+        """Check if a file exists
+
+        Unlike os.path.exists, this throws an exception if there is an error
+        checking if the file exists (for example, if there is a perms error on
+        the parent dir).
+
+        Returns:
+            bool: True if the file exists; False if not.
+        """
+        try:
+            os.stat(file_path)
+            return True
+        except OSError as e:
+            if e.errno != errno.ENOENT:
+                raise e
+            return False
+
+    @classmethod
     def check_file(cls, file_path, config_name):
         if file_path is None:
             raise ConfigError(
                 "Missing config for %s."
-                " You must specify a path for the config file. You can "
-                "do this with the -c or --config-path option. "
-                "Adding --generate-config along with --server-name "
-                "<server name> will generate a config file at the given path."
                 % (config_name,)
             )
-        if not os.path.exists(file_path):
+        try:
+            os.stat(file_path)
+        except OSError as e:
             raise ConfigError(
-                "File %s config for %s doesn't exist."
-                " Try running again with --generate-config"
-                % (file_path, config_name,)
+                "Error accessing file '%s' (config for %s): %s"
+                % (file_path, config_name, e.strerror)
             )
         return cls.abspath(file_path)
 
@@ -248,7 +266,7 @@ class Config(object):
                     " -c CONFIG-FILE\""
                 )
             (config_path,) = config_files
-            if not os.path.exists(config_path):
+            if not cls.path_exists(config_path):
                 if config_args.keys_directory:
                     config_dir_path = config_args.keys_directory
                 else:
@@ -261,33 +279,33 @@ class Config(object):
                         "Must specify a server_name to a generate config for."
                         " Pass -H server.name."
                     )
-                if not os.path.exists(config_dir_path):
+                if not cls.path_exists(config_dir_path):
                     os.makedirs(config_dir_path)
-                with open(config_path, "wb") as config_file:
-                    config_bytes, config = obj.generate_config(
+                with open(config_path, "w") as config_file:
+                    config_str, config = obj.generate_config(
                         config_dir_path=config_dir_path,
                         server_name=server_name,
                         report_stats=(config_args.report_stats == "yes"),
                         is_generating_file=True
                     )
                     obj.invoke_all("generate_files", config)
-                    config_file.write(config_bytes)
-                print (
+                    config_file.write(config_str)
+                print((
                     "A config file has been generated in %r for server name"
                     " %r with corresponding SSL keys and self-signed"
                     " certificates. Please review this file and customise it"
                     " to your needs."
-                ) % (config_path, server_name)
-                print (
+                ) % (config_path, server_name))
+                print(
                     "If this server name is incorrect, you will need to"
                     " regenerate the SSL certificates"
                 )
                 return
             else:
-                print (
+                print((
                     "Config file %r already exists. Generating any missing key"
                     " files."
-                ) % (config_path,)
+                ) % (config_path,))
                 generate_keys = True
 
         parser = argparse.ArgumentParser(
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 82c50b8240..277305e184 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -17,10 +17,12 @@ from ._base import Config, ConfigError
 from synapse.appservice import ApplicationService
 from synapse.types import UserID
 
-import urllib
 import yaml
 import logging
 
+from six import string_types
+from six.moves.urllib import parse as urlparse
+
 logger = logging.getLogger(__name__)
 
 
@@ -89,21 +91,21 @@ def _load_appservice(hostname, as_info, config_filename):
         "id", "as_token", "hs_token", "sender_localpart"
     ]
     for field in required_string_fields:
-        if not isinstance(as_info.get(field), basestring):
+        if not isinstance(as_info.get(field), string_types):
             raise KeyError("Required string field: '%s' (%s)" % (
                 field, config_filename,
             ))
 
     # 'url' must either be a string or explicitly null, not missing
     # to avoid accidentally turning off push for ASes.
-    if (not isinstance(as_info.get("url"), basestring) and
+    if (not isinstance(as_info.get("url"), string_types) and
             as_info.get("url", "") is not None):
         raise KeyError(
             "Required string field or explicit null: 'url' (%s)" % (config_filename,)
         )
 
     localpart = as_info["sender_localpart"]
-    if urllib.quote(localpart) != localpart:
+    if urlparse.quote(localpart) != localpart:
         raise ValueError(
             "sender_localpart needs characters which are not URL encoded."
         )
@@ -128,7 +130,7 @@ def _load_appservice(hostname, as_info, config_filename):
                         "Expected namespace entry in %s to be an object,"
                         " but got %s", ns, regex_obj
                     )
-                if not isinstance(regex_obj.get("regex"), basestring):
+                if not isinstance(regex_obj.get("regex"), string_types):
                     raise ValueError(
                         "Missing/bad type 'regex' key in %s", regex_obj
                     )
@@ -154,6 +156,7 @@ def _load_appservice(hostname, as_info, config_filename):
         )
     return ApplicationService(
         token=as_info["as_token"],
+        hostname=hostname,
         url=as_info["url"],
         namespaces=as_info["namespaces"],
         hs_token=as_info["hs_token"],
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 938f6f25f8..8109e5f95e 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -41,7 +41,7 @@ class CasConfig(Config):
         #cas_config:
         #   enabled: true
         #   server_url: "https://cas-server.com"
-        #   service_url: "https://homesever.domain.com:8448"
+        #   service_url: "https://homeserver.domain.com:8448"
         #   #required_attributes:
         #   #    name: value
         """
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 0030b5db1e..fe156b6930 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -71,6 +71,15 @@ class EmailConfig(Config):
             self.email_riot_base_url = email_config.get(
                 "riot_base_url", None
             )
+            self.email_smtp_user = email_config.get(
+                "smtp_user", None
+            )
+            self.email_smtp_pass = email_config.get(
+                "smtp_pass", None
+            )
+            self.require_transport_security = email_config.get(
+                "require_transport_security", False
+            )
             if "app_name" in email_config:
                 self.email_app_name = email_config["app_name"]
             else:
@@ -91,10 +100,17 @@ class EmailConfig(Config):
         # Defining a custom URL for Riot is only needed if email notifications
         # should contain links to a self-hosted installation of Riot; when set
         # the "app_name" setting is ignored.
+        #
+        # If your SMTP server requires authentication, the optional smtp_user &
+        # smtp_pass variables should be used
+        #
         #email:
         #   enable_notifs: false
         #   smtp_host: "localhost"
         #   smtp_port: 25
+        #   smtp_user: "exampleusername"
+        #   smtp_pass: "examplepassword"
+        #   require_transport_security: False
         #   notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
         #   app_name: Matrix
         #   template_dir: res/templates
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
new file mode 100644
index 0000000000..997fa2881f
--- /dev/null
+++ b/synapse/config/groups.py
@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class GroupsConfig(Config):
+    def read_config(self, config):
+        self.enable_group_creation = config.get("enable_group_creation", False)
+        self.group_creation_prefix = config.get("group_creation_prefix", "")
+
+    def default_config(self, **kwargs):
+        return """\
+        # Whether to allow non server admins to create groups on this server
+        enable_group_creation: false
+
+        # If enabled, non server admins can only create groups with local parts
+        # starting with this prefix
+        # group_creation_prefix: "unofficial/"
+        """
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 0f890fc04a..bf19cfee29 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -33,6 +33,10 @@ from .jwt import JWTConfig
 from .password_auth_providers import PasswordAuthProviderConfig
 from .emailconfig import EmailConfig
 from .workers import WorkerConfig
+from .push import PushConfig
+from .spam_checker import SpamCheckerConfig
+from .groups import GroupsConfig
+from .user_directory import UserDirectoryConfig
 
 
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -40,7 +44,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
                        JWTConfig, PasswordConfig, EmailConfig,
-                       WorkerConfig, PasswordAuthProviderConfig,):
+                       WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+                       SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
     pass
 
 
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 6ee643793e..4b8fc063d0 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -118,10 +118,9 @@ class KeyConfig(Config):
         signing_keys = self.read_file(signing_key_path, "signing_key")
         try:
             return read_signing_keys(signing_keys.splitlines(True))
-        except Exception:
+        except Exception as e:
             raise ConfigError(
-                "Error reading signing_key."
-                " Try running again with --generate-config"
+                "Error reading signing_key: %s" % (str(e))
             )
 
     def read_old_signing_keys(self, old_signing_keys):
@@ -141,7 +140,8 @@ class KeyConfig(Config):
 
     def generate_files(self, config):
         signing_key_path = config["signing_key_path"]
-        if not os.path.exists(signing_key_path):
+
+        if not self.path_exists(signing_key_path):
             with open(signing_key_path, "w") as signing_key_file:
                 key_id = "a_" + random_string(4)
                 write_signing_keys(
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 77ded0ad25..6a7228dc2f 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -28,34 +28,35 @@ DEFAULT_LOG_CONFIG = Template("""
 version: 1
 
 formatters:
-  precise:
-   format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
-- %(message)s'
+    precise:
+        format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
+%(request)s - %(message)s'
 
 filters:
-  context:
-    (): synapse.util.logcontext.LoggingContextFilter
-    request: ""
+    context:
+        (): synapse.util.logcontext.LoggingContextFilter
+        request: ""
 
 handlers:
-  file:
-    class: logging.handlers.RotatingFileHandler
-    formatter: precise
-    filename: ${log_file}
-    maxBytes: 104857600
-    backupCount: 10
-    filters: [context]
-    level: INFO
-  console:
-    class: logging.StreamHandler
-    formatter: precise
-    filters: [context]
+    file:
+        class: logging.handlers.RotatingFileHandler
+        formatter: precise
+        filename: ${log_file}
+        maxBytes: 104857600
+        backupCount: 10
+        filters: [context]
+    console:
+        class: logging.StreamHandler
+        formatter: precise
+        filters: [context]
 
 loggers:
     synapse:
         level: INFO
 
     synapse.storage.SQL:
+        # beware: increasing this to DEBUG will make synapse log sensitive
+        # information such as access tokens.
         level: INFO
 
 root:
@@ -68,21 +69,15 @@ class LoggingConfig(Config):
 
     def read_config(self, config):
         self.verbosity = config.get("verbose", 0)
+        self.no_redirect_stdio = config.get("no_redirect_stdio", False)
         self.log_config = self.abspath(config.get("log_config"))
         self.log_file = self.abspath(config.get("log_file"))
 
     def default_config(self, config_dir_path, server_name, **kwargs):
-        log_file = self.abspath("homeserver.log")
         log_config = self.abspath(
             os.path.join(config_dir_path, server_name + ".log.config")
         )
         return """
-        # Logging verbosity level.
-        verbose: 0
-
-        # File to write logging to
-        log_file: "%(log_file)s"
-
         # A yaml python logging config file
         log_config: "%(log_config)s"
         """ % locals()
@@ -90,6 +85,8 @@ class LoggingConfig(Config):
     def read_arguments(self, args):
         if args.verbose is not None:
             self.verbosity = args.verbose
+        if args.no_redirect_stdio is not None:
+            self.no_redirect_stdio = args.no_redirect_stdio
         if args.log_config is not None:
             self.log_config = args.log_config
         if args.log_file is not None:
@@ -99,48 +96,68 @@ class LoggingConfig(Config):
         logging_group = parser.add_argument_group("logging")
         logging_group.add_argument(
             '-v', '--verbose', dest="verbose", action='count',
-            help="The verbosity level."
+            help="The verbosity level. Specify multiple times to increase "
+            "verbosity. (Ignored if --log-config is specified.)"
         )
         logging_group.add_argument(
             '-f', '--log-file', dest="log_file",
-            help="File to log to."
+            help="File to log to. (Ignored if --log-config is specified.)"
         )
         logging_group.add_argument(
             '--log-config', dest="log_config", default=None,
             help="Python logging config file"
         )
+        logging_group.add_argument(
+            '-n', '--no-redirect-stdio',
+            action='store_true', default=None,
+            help="Do not redirect stdout/stderr to the log"
+        )
 
     def generate_files(self, config):
         log_config = config.get("log_config")
         if log_config and not os.path.exists(log_config):
-            with open(log_config, "wb") as log_config_file:
+            log_file = self.abspath("homeserver.log")
+            with open(log_config, "w") as log_config_file:
                 log_config_file.write(
-                    DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
+                    DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
                 )
 
-    def setup_logging(self):
-        setup_logging(self.log_config, self.log_file, self.verbosity)
 
+def setup_logging(config, use_worker_options=False):
+    """ Set up python logging
+
+    Args:
+        config (LoggingConfig | synapse.config.workers.WorkerConfig):
+            configuration data
+
+        use_worker_options (bool): True to use 'worker_log_config' and
+            'worker_log_file' options instead of 'log_config' and 'log_file'.
+    """
+    log_config = (config.worker_log_config if use_worker_options
+                  else config.log_config)
+    log_file = (config.worker_log_file if use_worker_options
+                else config.log_file)
 
-def setup_logging(log_config=None, log_file=None, verbosity=None):
     log_format = (
         "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
         " - %(message)s"
     )
-    if log_config is None:
 
+    if log_config is None:
+        # We don't have a logfile, so fall back to the 'verbosity' param from
+        # the config or cmdline. (Note that we generate a log config for new
+        # installs, so this will be an unusual case)
         level = logging.INFO
         level_for_storage = logging.INFO
-        if verbosity:
+        if config.verbosity:
             level = logging.DEBUG
-            if verbosity > 1:
+            if config.verbosity > 1:
                 level_for_storage = logging.DEBUG
 
-        # FIXME: we need a logging.WARN for a -q quiet option
         logger = logging.getLogger('')
         logger.setLevel(level)
 
-        logging.getLogger('synapse.storage').setLevel(level_for_storage)
+        logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
 
         formatter = logging.Formatter(log_format)
         if log_file:
@@ -153,24 +170,37 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
                 logger.info("Closing log file due to SIGHUP")
                 handler.doRollover()
                 logger.info("Opened new log file due to SIGHUP")
-
-            # TODO(paul): obviously this is a terrible mechanism for
-            #   stealing SIGHUP, because it means no other part of synapse
-            #   can use it instead. If we want to catch SIGHUP anywhere
-            #   else as well, I'd suggest we find a nicer way to broadcast
-            #   it around.
-            if getattr(signal, "SIGHUP"):
-                signal.signal(signal.SIGHUP, sighup)
         else:
             handler = logging.StreamHandler()
+
+            def sighup(signum, stack):
+                pass
+
         handler.setFormatter(formatter)
 
         handler.addFilter(LoggingContextFilter(request=""))
 
         logger.addHandler(handler)
     else:
-        with open(log_config, 'r') as f:
-            logging.config.dictConfig(yaml.load(f))
+        def load_log_config():
+            with open(log_config, 'r') as f:
+                logging.config.dictConfig(yaml.load(f))
+
+        def sighup(signum, stack):
+            # it might be better to use a file watcher or something for this.
+            logging.info("Reloading log config from %s due to SIGHUP",
+                         log_config)
+            load_log_config()
+
+        load_log_config()
+
+    # TODO(paul): obviously this is a terrible mechanism for
+    #   stealing SIGHUP, because it means no other part of synapse
+    #   can use it instead. If we want to catch SIGHUP anywhere
+    #   else as well, I'd suggest we find a nicer way to broadcast
+    #   it around.
+    if getattr(signal, "SIGHUP"):
+        signal.signal(signal.SIGHUP, sighup)
 
     # It's critical to point twisted's internal logging somewhere, otherwise it
     # stacks up and leaks kup to 64K object;
@@ -183,4 +213,7 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
     #
     # However this may not be too much of a problem if we are just writing to a file.
     observer = STDLibLogObserver()
-    globalLogBeginner.beginLoggingTo([observer])
+    globalLogBeginner.beginLoggingTo(
+        [observer],
+        redirectStandardIO=not config.no_redirect_stdio,
+    )
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83762d089a..6602c5b4c7 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -13,44 +13,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import Config, ConfigError
+from ._base import Config
 
-import importlib
+from synapse.util.module_loader import load_module
+
+LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
 
 
 class PasswordAuthProviderConfig(Config):
     def read_config(self, config):
         self.password_providers = []
+        providers = []
 
         # We want to be backwards compatible with the old `ldap_config`
         # param.
         ldap_config = config.get("ldap_config", {})
-        self.ldap_enabled = ldap_config.get("enabled", False)
-        if self.ldap_enabled:
-            from ldap_auth_provider import LdapAuthProvider
-            parsed_config = LdapAuthProvider.parse_config(ldap_config)
-            self.password_providers.append((LdapAuthProvider, parsed_config))
+        if ldap_config.get("enabled", False):
+            providers.append({
+                'module': LDAP_PROVIDER,
+                'config': ldap_config,
+            })
 
-        providers = config.get("password_providers", [])
+        providers.extend(config.get("password_providers", []))
         for provider in providers:
+            mod_name = provider['module']
+
             # This is for backwards compat when the ldap auth provider resided
             # in this package.
-            if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
-                from ldap_auth_provider import LdapAuthProvider
-                provider_class = LdapAuthProvider
-            else:
-                # We need to import the module, and then pick the class out of
-                # that, so we split based on the last dot.
-                module, clz = provider['module'].rsplit(".", 1)
-                module = importlib.import_module(module)
-                provider_class = getattr(module, clz)
+            if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
+                mod_name = LDAP_PROVIDER
+
+            (provider_class, provider_config) = load_module({
+                "module": mod_name,
+                "config": provider['config'],
+            })
 
-            try:
-                provider_config = provider_class.parse_config(provider["config"])
-            except Exception as e:
-                raise ConfigError(
-                    "Failed to parse config for %r: %r" % (provider['module'], e)
-                )
             self.password_providers.append((provider_class, provider_config))
 
     def default_config(self, **kwargs):
diff --git a/synapse/config/push.py b/synapse/config/push.py
new file mode 100644
index 0000000000..b7e0d46afa
--- /dev/null
+++ b/synapse/config/push.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class PushConfig(Config):
+    def read_config(self, config):
+        push_config = config.get("push", {})
+        self.push_include_content = push_config.get("include_content", True)
+
+        # There was a a 'redact_content' setting but mistakenly read from the
+        # 'email'section'. Check for the flag in the 'push' section, and log,
+        # but do not honour it to avoid nasty surprises when people upgrade.
+        if push_config.get("redact_content") is not None:
+            print(
+                "The push.redact_content content option has never worked. "
+                "Please set push.include_content if you want this behaviour"
+            )
+
+        # Now check for the one in the 'email' section and honour it,
+        # with a warning.
+        push_config = config.get("email", {})
+        redact_content = push_config.get("redact_content")
+        if redact_content is not None:
+            print(
+                "The 'email.redact_content' option is deprecated: "
+                "please set push.include_content instead"
+            )
+            self.push_include_content = not redact_content
+
+    def default_config(self, config_dir_path, server_name, **kwargs):
+        return """
+        # Clients requesting push notifications can either have the body of
+        # the message sent in the notification poke along with other details
+        # like the sender, or just the event ID and room ID (`event_id_only`).
+        # If clients choose the former, this option controls whether the
+        # notification request includes the content of the event (other details
+        # like the sender are still included). For `event_id_only` push, it
+        # has no effect.
+
+        # For modern android devices the notification content will still appear
+        # because it is loaded by the app. iPhone, however will send a
+        # notification saying only that a message arrived and who it came from.
+        #
+        #push:
+        #   include_content: true
+        """
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 87e500c97a..c5384b3ad4 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
                 strtobool(str(config["disable_registration"]))
             )
 
+        self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+        self.allowed_local_3pids = config.get("allowed_local_3pids", [])
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -41,6 +43,8 @@ class RegistrationConfig(Config):
             self.allow_guest_access and config.get("invite_3pid_guest", False)
         )
 
+        self.auto_join_rooms = config.get("auto_join_rooms", [])
+
     def default_config(self, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
 
@@ -50,13 +54,32 @@ class RegistrationConfig(Config):
         # Enable registration for new users.
         enable_registration: False
 
+        # The user must provide all of the below types of 3PID when registering.
+        #
+        # registrations_require_3pid:
+        #     - email
+        #     - msisdn
+
+        # Mandate that users are only allowed to associate certain formats of
+        # 3PIDs with accounts on this server.
+        #
+        # allowed_local_3pids:
+        #     - medium: email
+        #       pattern: ".*@matrix\\.org"
+        #     - medium: email
+        #       pattern: ".*@vector\\.im"
+        #     - medium: msisdn
+        #       pattern: "\\+44"
+
         # If set, allows registration by anyone who also has the shared
         # secret, even if registration is otherwise disabled.
         registration_shared_secret: "%(registration_shared_secret)s"
 
         # Set the number of bcrypt rounds used to generate password hash.
         # Larger numbers increase the work factor needed to generate the hash.
-        # The default number of rounds is 12.
+        # The default number is 12 (which equates to 2^12 rounds).
+        # N.B. that increasing this will exponentially increase the time required
+        # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
         bcrypt_rounds: 12
 
         # Allows users to register as guests without a password/email/etc, and
@@ -69,6 +92,12 @@ class RegistrationConfig(Config):
         trusted_third_party_id_servers:
             - matrix.org
             - vector.im
+            - riot.im
+
+        # Users who register on this homeserver will automatically be joined
+        # to these rooms
+        #auto_join_rooms:
+        #    - "#example:example.com"
         """ % locals()
 
     def add_arguments(self, parser):
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 2c6f57168e..25ea77738a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -16,6 +16,8 @@
 from ._base import Config, ConfigError
 from collections import namedtuple
 
+from synapse.util.module_loader import load_module
+
 
 MISSING_NETADDR = (
     "Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
     "ThumbnailRequirement", ["width", "height", "method", "media_type"]
 )
 
+MediaStorageProviderConfig = namedtuple(
+    "MediaStorageProviderConfig", (
+        "store_local",  # Whether to store newly uploaded local files
+        "store_remote",  # Whether to store newly downloaded remote files
+        "store_synchronous",  # Whether to wait for successful storage for local uploads
+    ),
+)
+
 
 def parse_thumbnail_requirements(thumbnail_sizes):
     """ Takes a list of dictionaries with "width", "height", and "method" keys
@@ -70,7 +80,64 @@ class ContentRepositoryConfig(Config):
         self.max_upload_size = self.parse_size(config["max_upload_size"])
         self.max_image_pixels = self.parse_size(config["max_image_pixels"])
         self.max_spider_size = self.parse_size(config["max_spider_size"])
+
         self.media_store_path = self.ensure_directory(config["media_store_path"])
+
+        backup_media_store_path = config.get("backup_media_store_path")
+
+        synchronous_backup_media_store = config.get(
+            "synchronous_backup_media_store", False
+        )
+
+        storage_providers = config.get("media_storage_providers", [])
+
+        if backup_media_store_path:
+            if storage_providers:
+                raise ConfigError(
+                    "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+                )
+
+            storage_providers = [{
+                "module": "file_system",
+                "store_local": True,
+                "store_synchronous": synchronous_backup_media_store,
+                "store_remote": True,
+                "config": {
+                    "directory": backup_media_store_path,
+                }
+            }]
+
+        # This is a list of config that can be used to create the storage
+        # providers. The entries are tuples of (Class, class_config,
+        # MediaStorageProviderConfig), where Class is the class of the provider,
+        # the class_config the config to pass to it, and
+        # MediaStorageProviderConfig are options for StorageProviderWrapper.
+        #
+        # We don't create the storage providers here as not all workers need
+        # them to be started.
+        self.media_storage_providers = []
+
+        for provider_config in storage_providers:
+            # We special case the module "file_system" so as not to need to
+            # expose FileStorageProviderBackend
+            if provider_config["module"] == "file_system":
+                provider_config["module"] = (
+                    "synapse.rest.media.v1.storage_provider"
+                    ".FileStorageProviderBackend"
+                )
+
+            provider_class, parsed_config = load_module(provider_config)
+
+            wrapper_config = MediaStorageProviderConfig(
+                provider_config.get("store_local", False),
+                provider_config.get("store_remote", False),
+                provider_config.get("store_synchronous", False),
+            )
+
+            self.media_storage_providers.append(
+                (provider_class, parsed_config, wrapper_config,)
+            )
+
         self.uploads_path = self.ensure_directory(config["uploads_path"])
         self.dynamic_thumbnails = config["dynamic_thumbnails"]
         self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -115,6 +182,20 @@ class ContentRepositoryConfig(Config):
         # Directory where uploaded images and attachments are stored.
         media_store_path: "%(media_store)s"
 
+        # Media storage providers allow media to be stored in different
+        # locations.
+        # media_storage_providers:
+        # - module: file_system
+        #   # Whether to write new local files.
+        #   store_local: false
+        #   # Whether to write new remote media
+        #   store_remote: false
+        #   # Whether to block upload requests waiting for write to this
+        #   # provider to complete
+        #   store_synchronous: false
+        #   config:
+        #     directory: /mnt/some/other/directory
+
         # Directory where in-progress uploads are stored.
         uploads_path: "%(uploads_path)s"
 
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 1f9999d57a..8f0b6d1f28 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,12 +30,42 @@ class ServerConfig(Config):
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
         self.public_baseurl = config.get("public_baseurl")
+        self.cpu_affinity = config.get("cpu_affinity")
 
         # Whether to send federation traffic out in this process. This only
         # applies to some federation traffic, and so shouldn't be used to
         # "disable" federation
         self.send_federation = config.get("send_federation", True)
 
+        # Whether to update the user directory or not. This should be set to
+        # false only if we are updating the user directory in a worker
+        self.update_user_directory = config.get("update_user_directory", True)
+
+        # whether to enable the media repository endpoints. This should be set
+        # to false if the media repository is running as a separate endpoint;
+        # doing so ensures that we will not run cache cleanup jobs on the
+        # master, potentially causing inconsistency.
+        self.enable_media_repo = config.get("enable_media_repo", True)
+
+        self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+
+        # Whether we should block invites sent to users on this server
+        # (other than those sent by local server admins)
+        self.block_non_admin_invites = config.get(
+            "block_non_admin_invites", False,
+        )
+
+        # FIXME: federation_domain_whitelist needs sytests
+        self.federation_domain_whitelist = None
+        federation_domain_whitelist = config.get(
+            "federation_domain_whitelist", None
+        )
+        # turn the whitelist into a hash for speed of lookup
+        if federation_domain_whitelist is not None:
+            self.federation_domain_whitelist = {}
+            for domain in federation_domain_whitelist:
+                self.federation_domain_whitelist[domain] = True
+
         if self.public_baseurl is not None:
             if self.public_baseurl[-1] != '/':
                 self.public_baseurl += '/'
@@ -141,9 +172,36 @@ class ServerConfig(Config):
         # When running as a daemon, the file to store the pid in
         pid_file: %(pid_file)s
 
+        # CPU affinity mask. Setting this restricts the CPUs on which the
+        # process will be scheduled. It is represented as a bitmask, with the
+        # lowest order bit corresponding to the first logical CPU and the
+        # highest order bit corresponding to the last logical CPU. Not all CPUs
+        # may exist on a given system but a mask may specify more CPUs than are
+        # present.
+        #
+        # For example:
+        #    0x00000001  is processor #0,
+        #    0x00000003  is processors #0 and #1,
+        #    0xFFFFFFFF  is all processors (#0 through #31).
+        #
+        # Pinning a Python process to a single CPU is desirable, because Python
+        # is inherently single-threaded due to the GIL, and can suffer a
+        # 30-40%% slowdown due to cache blow-out and thread context switching
+        # if the scheduler happens to schedule the underlying threads across
+        # different cores. See
+        # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
+        #
+        # cpu_affinity: 0xFFFFFFFF
+
         # Whether to serve a web client from the HTTP/HTTPS root resource.
         web_client: True
 
+        # The root directory to server for the above web client.
+        # If left undefined, synapse will serve the matrix-angular-sdk web client.
+        # Make sure matrix-angular-sdk is installed with pip if web_client is True
+        # and web_client_location is undefined
+        # web_client_location: "/path/to/web/root"
+
         # The public-facing base URL for the client API (not including _matrix/...)
         # public_baseurl: https://example.com:8448/
 
@@ -155,6 +213,25 @@ class ServerConfig(Config):
         # The GC threshold parameters to pass to `gc.set_threshold`, if defined
         # gc_thresholds: [700, 10, 10]
 
+        # Set the limit on the returned events in the timeline in the get
+        # and sync operations. The default value is -1, means no upper limit.
+        # filter_timeline_limit: 5000
+
+        # Whether room invites to users on this server should be blocked
+        # (except those sent by local server admins). The default is False.
+        # block_non_admin_invites: True
+
+        # Restrict federation to the following whitelist of domains.
+        # N.B. we recommend also firewalling your federation listener to limit
+        # inbound federation traffic as early as possible, rather than relying
+        # purely on this application-layer restriction.  If not specified, the
+        # default is to whitelist everything.
+        #
+        # federation_domain_whitelist:
+        #  - lon.example.com
+        #  - nyc.example.com
+        #  - syd.example.com
+
         # List of ports that Synapse should listen on, their purpose and their
         # configuration.
         listeners:
@@ -165,13 +242,12 @@ class ServerConfig(Config):
             port: %(bind_port)s
 
             # Local addresses to listen on.
-            # This will listen on all IPv4 addresses by default.
+            # On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
+            # addresses by default. For most other OSes, this will only listen
+            # on IPv6.
             bind_addresses:
+              - '::'
               - '0.0.0.0'
-              # Uncomment to listen on all IPv6 interfaces
-              # N.B: On at least Linux this will also listen on all IPv4
-              # addresses, so you will need to comment out the line above.
-              # - '::'
 
             # This is a 'http' listener, allows us to specify 'resources'.
             type: http
@@ -198,11 +274,18 @@ class ServerConfig(Config):
               - names: [federation]  # Federation APIs
                 compress: false
 
+            # optional list of additional endpoints which can be loaded via
+            # dynamic modules
+            # additional_resources:
+            #   "/_matrix/my/custom/endpoint":
+            #     module: my_module.CustomRequestHandler
+            #     config: {}
+
           # Unsecure HTTP listener,
           # For when matrix traffic passes through loadbalancer that unwraps TLS.
           - port: %(unsecure_port)s
             tls: false
-            bind_addresses: ['0.0.0.0']
+            bind_addresses: ['::', '0.0.0.0']
             type: http
 
             x_forwarded: false
@@ -216,7 +299,7 @@ class ServerConfig(Config):
           # Turn on the twisted ssh manhole service on localhost on the given
           # port.
           # - port: 9000
-          #   bind_address: 127.0.0.1
+          #   bind_addresses: ['::1', '127.0.0.1']
           #   type: manhole
         """ % locals()
 
@@ -254,7 +337,7 @@ def read_gc_thresholds(thresholds):
         return (
             int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
         )
-    except:
+    except Exception:
         raise ConfigError(
             "Value of `gc_threshold` must be a list of three integers if set"
         )
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
new file mode 100644
index 0000000000..3fec42bdb0
--- /dev/null
+++ b/synapse/config/spam_checker.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class SpamCheckerConfig(Config):
+    def read_config(self, config):
+        self.spam_checker = None
+
+        provider = config.get("spam_checker", None)
+        if provider is not None:
+            self.spam_checker = load_module(provider)
+
+    def default_config(self, **kwargs):
+        return """\
+        # spam_checker:
+        #     module: "my_custom_project.SuperSpamChecker"
+        #     config:
+        #         example_option: 'things'
+        """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index e081840a83..b66154bc7c 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -96,7 +96,7 @@ class TlsConfig(Config):
         # certificates returned by this server match one of the fingerprints.
         #
         # Synapse automatically adds the fingerprint of its own certificate
-        # to the list. So if federation traffic is handle directly by synapse
+        # to the list. So if federation traffic is handled directly by synapse
         # then no modification to the list is required.
         #
         # If synapse is run behind a load balancer that handles the TLS then it
@@ -109,6 +109,12 @@ class TlsConfig(Config):
         # key. It may be necessary to publish the fingerprints of a new
         # certificate and wait until the "valid_until_ts" of the previous key
         # responses have passed before deploying it.
+        #
+        # You can calculate a fingerprint from a given TLS listener via:
+        # openssl s_client -connect $host:$port < /dev/null 2> /dev/null |
+        #   openssl x509 -outform DER | openssl sha256 -binary | base64 | tr -d '='
+        # or by checking matrix.org/federationtester/api/report?server_name=$host
+        #
         tls_fingerprints: []
         # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
         """ % locals()
@@ -126,8 +132,8 @@ class TlsConfig(Config):
         tls_private_key_path = config["tls_private_key_path"]
         tls_dh_params_path = config["tls_dh_params_path"]
 
-        if not os.path.exists(tls_private_key_path):
-            with open(tls_private_key_path, "w") as private_key_file:
+        if not self.path_exists(tls_private_key_path):
+            with open(tls_private_key_path, "wb") as private_key_file:
                 tls_private_key = crypto.PKey()
                 tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
                 private_key_pem = crypto.dump_privatekey(
@@ -141,8 +147,8 @@ class TlsConfig(Config):
                     crypto.FILETYPE_PEM, private_key_pem
                 )
 
-        if not os.path.exists(tls_certificate_path):
-            with open(tls_certificate_path, "w") as certificate_file:
+        if not self.path_exists(tls_certificate_path):
+            with open(tls_certificate_path, "wb") as certificate_file:
                 cert = crypto.X509()
                 subject = cert.get_subject()
                 subject.CN = config["server_name"]
@@ -159,7 +165,7 @@ class TlsConfig(Config):
 
                 certificate_file.write(cert_pem)
 
-        if not os.path.exists(tls_dh_params_path):
+        if not self.path_exists(tls_dh_params_path):
             if GENERATE_DH_PARAMS:
                 subprocess.check_call([
                     "openssl", "dhparam",
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
new file mode 100644
index 0000000000..38e8947843
--- /dev/null
+++ b/synapse/config/user_directory.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class UserDirectoryConfig(Config):
+    """User Directory Configuration
+    Configuration for the behaviour of the /user_directory API
+    """
+
+    def read_config(self, config):
+        self.user_directory_search_all_users = False
+        user_directory_config = config.get("user_directory", None)
+        if user_directory_config:
+            self.user_directory_search_all_users = (
+                user_directory_config.get("search_all_users", False)
+            )
+
+    def default_config(self, config_dir_path, server_name, **kwargs):
+        return """
+        # User Directory configuration
+        #
+        # 'search_all_users' defines whether to search all users visible to your HS
+        # when searching the user directory, rather than limiting to users visible
+        # in public rooms.  Defaults to false.  If you set it True, you'll have to run
+        # UPDATE user_directory_stream_pos SET stream_id = NULL;
+        # on your database to tell it to rebuild the user_directory search indexes.
+        #
+        #user_directory:
+        #   search_all_users: false
+        """
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index eeb693027b..3a4e16fa96 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -23,6 +23,7 @@ class VoipConfig(Config):
         self.turn_username = config.get("turn_username")
         self.turn_password = config.get("turn_password")
         self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
+        self.turn_allow_guests = config.get("turn_allow_guests", True)
 
     def default_config(self, **kwargs):
         return """\
@@ -41,4 +42,11 @@ class VoipConfig(Config):
 
         # How long generated TURN credentials last
         turn_user_lifetime: "1h"
+
+        # Whether guests should be allowed to use the TURN server.
+        # This defaults to True, otherwise VoIP will be unreliable for guests.
+        # However, it does introduce a slight security risk as it allows users to
+        # connect to arbitrary endpoints without having first signed up for a
+        # valid account (e.g. by passing a CAPTCHA).
+        turn_allow_guests: True
         """
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index b165c67ee7..80baf0ce0e 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -23,12 +23,30 @@ class WorkerConfig(Config):
 
     def read_config(self, config):
         self.worker_app = config.get("worker_app")
+
+        # Canonicalise worker_app so that master always has None
+        if self.worker_app == "synapse.app.homeserver":
+            self.worker_app = None
+
         self.worker_listeners = config.get("worker_listeners")
         self.worker_daemonize = config.get("worker_daemonize")
         self.worker_pid_file = config.get("worker_pid_file")
         self.worker_log_file = config.get("worker_log_file")
         self.worker_log_config = config.get("worker_log_config")
-        self.worker_replication_url = config.get("worker_replication_url")
+
+        # The host used to connect to the main synapse
+        self.worker_replication_host = config.get("worker_replication_host", None)
+
+        # The port on the main synapse for TCP replication
+        self.worker_replication_port = config.get("worker_replication_port", None)
+
+        # The port on the main synapse for HTTP replication endpoint
+        self.worker_replication_http_port = config.get("worker_replication_http_port")
+
+        self.worker_name = config.get("worker_name", self.worker_app)
+
+        self.worker_main_http_uri = config.get("worker_main_http_uri", None)
+        self.worker_cpu_affinity = config.get("worker_cpu_affinity")
 
         if self.worker_listeners:
             for listener in self.worker_listeners:
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index aad4752fe7..0397f73ab4 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
+from OpenSSL import SSL, crypto
+from twisted.internet._sslverify import _defaultCurveName
 
 import logging
 
@@ -32,9 +32,10 @@ class ServerContextFactory(ssl.ContextFactory):
     @staticmethod
     def configure_context(context, config):
         try:
-            _ecCurve = _OpenSSLECCurve(_defaultCurveName)
-            _ecCurve.addECKeyToContext(context)
-        except:
+            _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+            context.set_tmp_ecdh(_ecCurve)
+
+        except Exception:
             logger.exception("Failed to enable elliptic curve for TLS")
         context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
         context.use_certificate_chain_file(config.tls_certificate_file)
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index ec7711ba7d..aaa3efaca3 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -32,18 +32,25 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     """Check whether the hash for this PDU matches the contents"""
     name, expected_hash = compute_content_hash(event, hash_algorithm)
     logger.debug("Expecting hash: %s", encode_base64(expected_hash))
-    if name not in event.hashes:
+
+    # some malformed events lack a 'hashes'. Protect against it being missing
+    # or a weird type by basically treating it the same as an unhashed event.
+    hashes = event.get("hashes")
+    if not isinstance(hashes, dict):
+        raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+
+    if name not in hashes:
         raise SynapseError(
             400,
             "Algorithm %s not in hashes %s" % (
-                name, list(event.hashes),
+                name, list(hashes),
             ),
             Codes.UNAUTHORIZED,
         )
-    message_hash_base64 = event.hashes[name]
+    message_hash_base64 = hashes[name]
     try:
         message_hash_bytes = decode_base64(message_hash_base64)
-    except:
+    except Exception:
         raise SynapseError(
             400,
             "Invalid base64: %s" % (message_hash_base64,),
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c2bd64d6c2..f1fd488b90 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -13,14 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from synapse.util import logcontext
 from twisted.web.http import HTTPClient
 from twisted.internet.protocol import Factory
 from twisted.internet import defer, reactor
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import (
-    preserve_context_over_fn, preserve_context_over_deferred
-)
 import simplejson as json
 import logging
 
@@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     for i in range(5):
         try:
-            protocol = yield preserve_context_over_fn(
-                endpoint.connect, factory
-            )
-            server_response, server_certificate = yield preserve_context_over_deferred(
-                protocol.remote_key
-            )
-            defer.returnValue((server_response, server_certificate))
-            return
+            with logcontext.PreserveLoggingContext():
+                protocol = yield endpoint.connect(factory)
+                server_response, server_certificate = yield protocol.remote_key
+                defer.returnValue((server_response, server_certificate))
         except SynapseKeyClientError as e:
             logger.exception("Error getting key for %r" % (server_name,))
             if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d7211ee9b3..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,12 +16,11 @@
 
 from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
-from synapse.util.retryutils import get_retry_limiter
-from synapse.util import unwrapFirstError
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
-    preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
-    preserve_fn
+    PreserveLoggingContext,
+    preserve_fn,
+    run_in_background,
 )
 from synapse.util.metrics import Measure
 
@@ -58,7 +58,8 @@ Attributes:
     json_object(dict): The JSON object to verify.
     deferred(twisted.internet.defer.Deferred):
         A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched
+        a verify key has been fetched. The deferreds' callbacks are run with no
+        logcontext.
 """
 
 
@@ -75,31 +76,41 @@ class Keyring(object):
         self.perspective_servers = self.config.perspectives
         self.hs = hs
 
+        # map from server name to Deferred. Has an entry for each server with
+        # an ongoing key download; the Deferred completes once the download
+        # completes.
+        #
+        # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
     def verify_json_for_server(self, server_name, json_object):
-        return self.verify_json_objects_for_server(
-            [(server_name, json_object)]
-        )[0]
+        return logcontext.make_deferred_yieldable(
+            self.verify_json_objects_for_server(
+                [(server_name, json_object)]
+            )[0]
+        )
 
     def verify_json_objects_for_server(self, server_and_json):
-        """Bulk verfies signatures of json objects, bulk fetching keys as
+        """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
             server_and_json (list): List of pairs of (server_name, json_object)
 
         Returns:
-            list of deferreds indicating success or failure to verify each
-            json object's signature for the given server_name.
+            List<Deferred>: for each input pair, a deferred indicating success
+                or failure to verify each json object's signature for the given
+                server_name. The deferreds run their callbacks in the sentinel
+                logcontext.
         """
         verify_requests = []
 
         for server_name, json_object in server_and_json:
-            logger.debug("Verifying for %s", server_name)
 
             key_ids = signature_ids(json_object, server_name)
             if not key_ids:
+                logger.warn("Request from %s: no supported signature keys",
+                            server_name)
                 deferred = defer.fail(SynapseError(
                     400,
                     "Not signed with a supported algorithm",
@@ -108,76 +119,69 @@ class Keyring(object):
             else:
                 deferred = defer.Deferred()
 
+            logger.debug("Verifying for %s with key_ids %s",
+                         server_name, key_ids)
+
             verify_request = VerifyKeyRequest(
                 server_name, key_ids, json_object, deferred
             )
 
             verify_requests.append(verify_request)
 
-        @defer.inlineCallbacks
-        def handle_key_deferred(verify_request):
-            server_name = verify_request.server_name
-            try:
-                _, key_id, verify_key = yield verify_request.deferred
-            except IOError as e:
-                logger.warn(
-                    "Got IOError when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    502,
-                    "Error downloading keys for %s" % (server_name,),
-                    Codes.UNAUTHORIZED,
-                )
-            except Exception as e:
-                logger.exception(
-                    "Got Exception when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    401,
-                    "No key for %s with id %s" % (server_name, key_ids),
-                    Codes.UNAUTHORIZED,
-                )
+        run_in_background(self._start_key_lookups, verify_requests)
 
-            json_object = verify_request.json_object
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
+        handle = preserve_fn(_handle_key_deferred)
+        return [
+            handle(rq) for rq in verify_requests
+        ]
 
-            try:
-                verify_signed_json(json_object, server_name, verify_key)
-            except:
-                raise SynapseError(
-                    401,
-                    "Invalid signature for server %s with key %s:%s" % (
-                        server_name, verify_key.alg, verify_key.version
-                    ),
-                    Codes.UNAUTHORIZED,
-                )
+    @defer.inlineCallbacks
+    def _start_key_lookups(self, verify_requests):
+        """Sets off the key fetches for each verify request
 
-        server_to_deferred = {
-            server_name: defer.Deferred()
-            for server_name, _ in server_and_json
-        }
+        Once each fetch completes, verify_request.deferred will be resolved.
 
-        with PreserveLoggingContext():
+        Args:
+            verify_requests (List[VerifyKeyRequest]):
+        """
+
+        try:
+            # create a deferred for each server we're going to look up the keys
+            # for; we'll resolve them once we have completed our lookups.
+            # These will be passed into wait_for_previous_lookups to block
+            # any other lookups until we have finished.
+            # The deferreds are called with no logcontext.
+            server_to_deferred = {
+                rq.server_name: defer.Deferred()
+                for rq in verify_requests
+            }
 
             # We want to wait for any previous lookups to complete before
             # proceeding.
-            wait_on_deferred = self.wait_for_previous_lookups(
-                [server_name for server_name, _ in server_and_json],
+            yield self.wait_for_previous_lookups(
+                [rq.server_name for rq in verify_requests],
                 server_to_deferred,
             )
 
             # Actually start fetching keys.
-            wait_on_deferred.addBoth(
-                lambda _: self.get_server_verify_keys(verify_requests)
-            )
+            self._get_server_verify_keys(verify_requests)
 
             # When we've finished fetching all the keys for a given server_name,
             # resolve the deferred passed to `wait_for_previous_lookups` so that
             # any lookups waiting will proceed.
+            #
+            # map from server name to a set of request ids
             server_to_request_ids = {}
 
-            def remove_deferreds(res, server_name, verify_request):
+            for verify_request in verify_requests:
+                server_name = verify_request.server_name
+                request_id = id(verify_request)
+                server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+            def remove_deferreds(res, verify_request):
+                server_name = verify_request.server_name
                 request_id = id(verify_request)
                 server_to_request_ids[server_name].discard(request_id)
                 if not server_to_request_ids[server_name]:
@@ -187,17 +191,11 @@ class Keyring(object):
                 return res
 
             for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-                deferred.addBoth(remove_deferreds, server_name, verify_request)
-
-        # Pass those keys to handle_key_deferred so that the json object
-        # signatures can be verified
-        return [
-            preserve_context_over_fn(handle_key_deferred, verify_request)
-            for verify_request in verify_requests
-        ]
+                verify_request.deferred.addBoth(
+                    remove_deferreds, verify_request,
+                )
+        except Exception:
+            logger.exception("Error starting key lookups")
 
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -206,7 +204,13 @@ class Keyring(object):
         Args:
             server_names (list): list of server_names we want to lookup
             server_to_deferred (dict): server_name to deferred which gets
-                resolved once we've finished looking up keys for that server
+                resolved once we've finished looking up keys for that server.
+                The Deferreds should be regular twisted ones which call their
+                callbacks with no logcontext.
+
+        Returns: a Deferred which resolves once all key lookups for the given
+            servers have completed. Follows the synapse rules of logcontext
+            preservation.
         """
         while True:
             wait_on = [
@@ -220,19 +224,23 @@ class Keyring(object):
             else:
                 break
 
+        def rm(r, server_name_):
+            self.key_downloads.pop(server_name_, None)
+            return r
+
         for server_name, deferred in server_to_deferred.items():
-            d = ObservableDeferred(preserve_context_over_deferred(deferred))
-            self.key_downloads[server_name] = d
+            self.key_downloads[server_name] = deferred
+            deferred.addBoth(rm, server_name)
 
-            def rm(r, server_name):
-                self.key_downloads.pop(server_name, None)
-                return r
+    def _get_server_verify_keys(self, verify_requests):
+        """Tries to find at least one key for each verify request
 
-            d.addBoth(rm, server_name)
+        For each verify_request, verify_request.deferred is called back with
+        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
+        with a SynapseError if none of the keys are found.
 
-    def get_server_verify_keys(self, verify_requests):
-        """Takes a dict of KeyGroups and tries to find at least one key for
-        each group.
+        Args:
+            verify_requests (list[VerifyKeyRequest]): list of verify requests
         """
 
         # These are functions that produce keys given a list of key ids
@@ -245,8 +253,11 @@ class Keyring(object):
         @defer.inlineCallbacks
         def do_iterations():
             with Measure(self.clock, "get_server_verify_keys"):
+                # dict[str, dict[str, VerifyKey]]: results so far.
+                # map server_name -> key_id -> VerifyKey
                 merged_results = {}
 
+                # dict[str, set(str)]: keys to fetch for each server
                 missing_keys = {}
                 for verify_request in verify_requests:
                     missing_keys.setdefault(verify_request.server_name, set()).update(
@@ -290,33 +301,46 @@ class Keyring(object):
                     if not missing_keys:
                         break
 
-                for verify_request in requests_missing_keys.values():
-                    verify_request.deferred.errback(SynapseError(
-                        401,
-                        "No key for %s with id %s" % (
-                            verify_request.server_name, verify_request.key_ids,
-                        ),
-                        Codes.UNAUTHORIZED,
-                    ))
+                with PreserveLoggingContext():
+                    for verify_request in requests_missing_keys:
+                        verify_request.deferred.errback(SynapseError(
+                            401,
+                            "No key for %s with id %s" % (
+                                verify_request.server_name, verify_request.key_ids,
+                            ),
+                            Codes.UNAUTHORIZED,
+                        ))
 
         def on_err(err):
-            for verify_request in verify_requests:
-                if not verify_request.deferred.called:
-                    verify_request.deferred.errback(err)
+            with PreserveLoggingContext():
+                for verify_request in verify_requests:
+                    if not verify_request.deferred.called:
+                        verify_request.deferred.errback(err)
 
-        do_iterations().addErrback(on_err)
+        run_in_background(do_iterations).addErrback(on_err)
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        """
+
+        Args:
+            server_name_and_key_ids (list[(str, iterable[str])]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+                server_name -> key_id -> VerifyKey
+        """
+        res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.get_server_verify_keys)(
-                    server_name, key_ids
+                run_in_background(
+                    self.store.get_server_verify_keys,
+                    server_name, key_ids,
                 ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(dict(res))
 
@@ -333,17 +357,17 @@ class Keyring(object):
                 logger.exception(
                     "Unable to get key from %r: %s %s",
                     perspective_name,
-                    type(e).__name__, str(e.message),
+                    type(e).__name__, str(e),
                 )
                 defer.returnValue({})
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(get_key)(p_name, p_keys)
+                run_in_background(get_key, p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         union_of_keys = {}
         for result in results:
@@ -356,40 +380,34 @@ class Keyring(object):
     def get_keys_from_server(self, server_name_and_key_ids):
         @defer.inlineCallbacks
         def get_key(server_name, key_ids):
-            limiter = yield get_retry_limiter(
-                server_name,
-                self.clock,
-                self.store,
-            )
-            with limiter:
-                keys = None
-                try:
-                    keys = yield self.get_server_verify_key_v2_direct(
-                        server_name, key_ids
-                    )
-                except Exception as e:
-                    logger.info(
-                        "Unable to get key %r for %r directly: %s %s",
-                        key_ids, server_name,
-                        type(e).__name__, str(e.message),
-                    )
+            keys = None
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(
+                    server_name, key_ids
+                )
+            except Exception as e:
+                logger.info(
+                    "Unable to get key %r for %r directly: %s %s",
+                    key_ids, server_name,
+                    type(e).__name__, str(e),
+                )
 
-                if not keys:
-                    keys = yield self.get_server_verify_key_v1_direct(
-                        server_name, key_ids
-                    )
+            if not keys:
+                keys = yield self.get_server_verify_key_v1_direct(
+                    server_name, key_ids
+                )
 
-                    keys = {server_name: keys}
+                keys = {server_name: keys}
 
             defer.returnValue(keys)
 
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(get_key)(server_name, key_ids)
+                run_in_background(get_key, server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         merged = {}
         for result in results:
@@ -466,9 +484,10 @@ class Keyring(object):
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                     server_name=server_name,
                     from_server=perspective_name,
                     verify_keys=response_keys,
@@ -476,7 +495,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -524,9 +543,10 @@ class Keyring(object):
 
             keys.update(response_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store_keys)(
+                run_in_background(
+                    self.store_keys,
                     server_name=key_server_name,
                     from_server=server_name,
                     verify_keys=verify_keys,
@@ -534,7 +554,7 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(keys)
 
@@ -600,9 +620,10 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
 
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.store_server_keys_json)(
+                run_in_background(
+                    self.store.store_server_keys_json,
                     server_name=server_name,
                     key_id=key_id,
                     from_server=server_name,
@@ -613,7 +634,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         results[server_name] = response_keys
 
@@ -691,7 +712,6 @@ class Keyring(object):
 
         defer.returnValue(verify_keys)
 
-    @defer.inlineCallbacks
     def store_keys(self, server_name, from_server, verify_keys):
         """Store a collection of verify keys for a given server
         Args:
@@ -702,12 +722,57 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield preserve_context_over_deferred(defer.gatherResults(
+        return logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self.store.store_server_verify_key)(
+                run_in_background(
+                    self.store.store_server_verify_key,
                     server_name, server_name, key.time_added, key
                 )
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+    server_name = verify_request.server_name
+    try:
+        with PreserveLoggingContext():
+            _, key_id, verify_key = yield verify_request.deferred
+    except IOError as e:
+        logger.warn(
+            "Got IOError when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e),
+        )
+        raise SynapseError(
+            502,
+            "Error downloading keys for %s" % (server_name,),
+            Codes.UNAUTHORIZED,
+        )
+    except Exception as e:
+        logger.exception(
+            "Got Exception when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e),
+        )
+        raise SynapseError(
+            401,
+            "No key for %s with id %s" % (server_name, verify_request.key_ids),
+            Codes.UNAUTHORIZED,
+        )
+
+    json_object = verify_request.json_object
+
+    logger.debug("Got key %s %s:%s for server %s, verifying" % (
+        key_id, verify_key.alg, verify_key.version, server_name,
+    ))
+    try:
+        verify_signed_json(json_object, server_name, verify_key)
+    except Exception:
+        raise SynapseError(
+            401,
+            "Invalid signature for server %s with key %s:%s" % (
+                server_name, verify_key.alg, verify_key.version
+            ),
+            Codes.UNAUTHORIZED,
+        )
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 4096c606f1..cd5627e36a 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
         # TODO (erikj): Implement kicks.
         if target_banned and user_level < ban_level:
             raise AuthError(
-                403, "You cannot unban user &s." % (target_user_id,)
+                403, "You cannot unban user %s." % (target_user_id,)
             )
         elif target_user_id != event.user_id:
             kick_level = _get_named_level(auth_events, "kick", 50)
@@ -443,12 +443,12 @@ def _check_power_levels(event, auth_events):
     for k, v in user_list.items():
         try:
             UserID.from_string(k)
-        except:
+        except Exception:
             raise SynapseError(400, "Not a valid user_id: %s" % (k,))
 
         try:
             int(v)
-        except:
+        except Exception:
             raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
     key = (event.type, event.state_key, )
@@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
         ("invite", None),
     ]
 
-    old_list = current_state.content.get("users")
+    old_list = current_state.content.get("users", {})
     for user in set(old_list.keys() + user_list.keys()):
         levels_to_check.append(
             (user, "users")
         )
 
-    old_list = current_state.content.get("events")
-    new_list = event.content.get("events")
+    old_list = current_state.content.get("events", {})
+    new_list = event.content.get("events", {})
     for ev_id in set(old_list.keys() + new_list.keys()):
         levels_to_check.append(
             (ev_id, "events")
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..c3ff85c49a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -47,14 +47,26 @@ class _EventInternalMetadata(object):
 
 
 def _event_dict_property(key):
+    # We want to be able to use hasattr with the event dict properties.
+    # However, (on python3) hasattr expects AttributeError to be raised. Hence,
+    # we need to transform the KeyError into an AttributeError
     def getter(self):
-        return self._event_dict[key]
+        try:
+            return self._event_dict[key]
+        except KeyError:
+            raise AttributeError(key)
 
     def setter(self, v):
-        self._event_dict[key] = v
+        try:
+            self._event_dict[key] = v
+        except KeyError:
+            raise AttributeError(key)
 
     def delete(self):
-        del self._event_dict[key]
+        try:
+            del self._event_dict[key]
+        except KeyError:
+            raise AttributeError(key)
 
     return property(
         getter,
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 365fd96bd2..13fbba68c0 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -55,7 +55,7 @@ class EventBuilderFactory(object):
 
         local_part = str(int(self.clock.time())) + i + random_string(5)
 
-        e_id = EventID.create(local_part, self.hostname)
+        e_id = EventID(local_part, self.hostname)
 
         return e_id.to_string()
 
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 11605b34a3..8e684d91b5 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -13,17 +13,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
+from frozendict import frozendict
+
 
 class EventContext(object):
+    """
+    Attributes:
+        current_state_ids (dict[(str, str), str]):
+            The current state map including the current event.
+            (type, state_key) -> event_id
+
+        prev_state_ids (dict[(str, str), str]):
+            The current state map excluding the current event.
+            (type, state_key) -> event_id
+
+        state_group (int|None): state group id, if the state has been stored
+            as a state group. This is usually only None if e.g. the event is
+            an outlier.
+        rejected (bool|str): A rejection reason if the event was rejected, else
+            False
+
+        push_actions (list[(str, list[object])]): list of (user_id, actions)
+            tuples
+
+        prev_group (int): Previously persisted state group. ``None`` for an
+            outlier.
+        delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
+            (type, state_key) -> event_id. ``None`` for an outlier.
+
+        prev_state_events (?): XXX: is this ever set to anything other than
+            the empty list?
+    """
+
     __slots__ = [
         "current_state_ids",
         "prev_state_ids",
         "state_group",
         "rejected",
-        "push_actions",
         "prev_group",
         "delta_ids",
         "prev_state_events",
+        "app_service",
     ]
 
     def __init__(self):
@@ -34,7 +66,6 @@ class EventContext(object):
         self.state_group = None
 
         self.rejected = False
-        self.push_actions = []
 
         # A previously persisted state group and a delta between that
         # and this state.
@@ -42,3 +73,100 @@ class EventContext(object):
         self.delta_ids = None
 
         self.prev_state_events = None
+
+        self.app_service = None
+
+    def serialize(self, event):
+        """Converts self to a type that can be serialized as JSON, and then
+        deserialized by `deserialize`
+
+        Args:
+            event (FrozenEvent): The event that this context relates to
+
+        Returns:
+            dict
+        """
+
+        # We don't serialize the full state dicts, instead they get pulled out
+        # of the DB on the other side. However, the other side can't figure out
+        # the prev_state_ids, so if we're a state event we include the event
+        # id that we replaced in the state.
+        if event.is_state():
+            prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
+        else:
+            prev_state_id = None
+
+        return {
+            "prev_state_id": prev_state_id,
+            "event_type": event.type,
+            "event_state_key": event.state_key if event.is_state() else None,
+            "state_group": self.state_group,
+            "rejected": self.rejected,
+            "prev_group": self.prev_group,
+            "delta_ids": _encode_state_dict(self.delta_ids),
+            "prev_state_events": self.prev_state_events,
+            "app_service_id": self.app_service.id if self.app_service else None
+        }
+
+    @staticmethod
+    @defer.inlineCallbacks
+    def deserialize(store, input):
+        """Converts a dict that was produced by `serialize` back into a
+        EventContext.
+
+        Args:
+            store (DataStore): Used to convert AS ID to AS object
+            input (dict): A dict produced by `serialize`
+
+        Returns:
+            EventContext
+        """
+        context = EventContext()
+        context.state_group = input["state_group"]
+        context.rejected = input["rejected"]
+        context.prev_group = input["prev_group"]
+        context.delta_ids = _decode_state_dict(input["delta_ids"])
+        context.prev_state_events = input["prev_state_events"]
+
+        # We use the state_group and prev_state_id stuff to pull the
+        # current_state_ids out of the DB and construct prev_state_ids.
+        prev_state_id = input["prev_state_id"]
+        event_type = input["event_type"]
+        event_state_key = input["event_state_key"]
+
+        context.current_state_ids = yield store.get_state_ids_for_group(
+            context.state_group,
+        )
+        if prev_state_id and event_state_key:
+            context.prev_state_ids = dict(context.current_state_ids)
+            context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
+        else:
+            context.prev_state_ids = context.current_state_ids
+
+        app_service_id = input["app_service_id"]
+        if app_service_id:
+            context.app_service = store.get_app_service_by_id(app_service_id)
+
+        defer.returnValue(context)
+
+
+def _encode_state_dict(state_dict):
+    """Since dicts of (type, state_key) -> event_id cannot be serialized in
+    JSON we need to convert them to a form that can.
+    """
+    if state_dict is None:
+        return None
+
+    return [
+        (etype, state_key, v)
+        for (etype, state_key), v in state_dict.iteritems()
+    ]
+
+
+def _decode_state_dict(input):
+    """Decodes a state dict encoded using `_encode_state_dict` above
+    """
+    if input is None:
+        return None
+
+    return frozendict({(etype, state_key,): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
new file mode 100644
index 0000000000..633e068eb8
--- /dev/null
+++ b/synapse/events/spamcheck.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class SpamChecker(object):
+    def __init__(self, hs):
+        self.spam_checker = None
+
+        module = None
+        config = None
+        try:
+            module, config = hs.config.spam_checker
+        except Exception:
+            pass
+
+        if module is not None:
+            self.spam_checker = module(config=config)
+
+    def check_event_for_spam(self, event):
+        """Checks if a given event is considered "spammy" by this server.
+
+        If the server considers an event spammy, then it will be rejected if
+        sent by a local user. If it is sent by a user on another server, then
+        users receive a blank event.
+
+        Args:
+            event (synapse.events.EventBase): the event to be checked
+
+        Returns:
+            bool: True if the event is spammy.
+        """
+        if self.spam_checker is None:
+            return False
+
+        return self.spam_checker.check_event_for_spam(event)
+
+    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+        """Checks if a given user may send an invite
+
+        If this method returns false, the invite will be rejected.
+
+        Args:
+            userid (string): The sender's user ID
+
+        Returns:
+            bool: True if the user may send an invite, otherwise False
+        """
+        if self.spam_checker is None:
+            return True
+
+        return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+
+    def user_may_create_room(self, userid):
+        """Checks if a given user may create a room
+
+        If this method returns false, the creation request will be rejected.
+
+        Args:
+            userid (string): The sender's user ID
+
+        Returns:
+            bool: True if the user may create a room, otherwise False
+        """
+        if self.spam_checker is None:
+            return True
+
+        return self.spam_checker.user_may_create_room(userid)
+
+    def user_may_create_room_alias(self, userid, room_alias):
+        """Checks if a given user may create a room alias
+
+        If this method returns false, the association request will be rejected.
+
+        Args:
+            userid (string): The sender's user ID
+            room_alias (string): The alias to be created
+
+        Returns:
+            bool: True if the user may create a room alias, otherwise False
+        """
+        if self.spam_checker is None:
+            return True
+
+        return self.spam_checker.user_may_create_room_alias(userid, room_alias)
+
+    def user_may_publish_room(self, userid, room_id):
+        """Checks if a given user may publish a room to the directory
+
+        If this method returns false, the publish request will be rejected.
+
+        Args:
+            userid (string): The sender's user ID
+            room_id (string): The ID of the room that would be published
+
+        Returns:
+            bool: True if the user may publish the room, otherwise False
+        """
+        if self.spam_checker is None:
+            return True
+
+        return self.spam_checker.user_may_publish_room(userid, room_id)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 5bbaef8187..824f4a42e3 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -225,7 +225,22 @@ def format_event_for_client_v2_without_room_id(d):
 
 def serialize_event(e, time_now_ms, as_client_event=True,
                     event_format=format_event_for_client_v1,
-                    token_id=None, only_event_fields=None):
+                    token_id=None, only_event_fields=None, is_invite=False):
+    """Serialize event for clients
+
+    Args:
+        e (EventBase)
+        time_now_ms (int)
+        as_client_event (bool)
+        event_format
+        token_id
+        only_event_fields
+        is_invite (bool): Whether this is an invite that is being sent to the
+            invitee
+
+    Returns:
+        dict
+    """
     # FIXME(erikj): To handle the case of presence events and the like
     if not isinstance(e, EventBase):
         return e
@@ -251,6 +266,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
             if txn_id is not None:
                 d["unsigned"]["transaction_id"] = txn_id
 
+    # If this is an invite for somebody else, then we don't care about the
+    # invite_room_state as that's meant solely for the invitee. Other clients
+    # will already have the state since they're in the room.
+    if not is_invite:
+        d["unsigned"].pop("invite_room_state", None)
+
     if as_client_event:
         d = event_format(d)
 
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
 
 """ This package includes all the federation specific logic.
 """
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
-    transport = hs.get_federation_transport_client()
-
-    return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..4cc98a3fe8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
+import six
 
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import SynapseError, Codes
 from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
-import logging
-
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
 
 
 class FederationBase(object):
     def __init__(self, hs):
-        pass
+        self.hs = hs
+
+        self.server_name = hs.hostname
+        self.keyring = hs.get_keyring()
+        self.spam_checker = hs.get_spam_checker()
+        self.store = hs.get_datastore()
+        self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +60,52 @@ class FederationBase(object):
         """
         deferreds = self._check_sigs_and_hashes(pdus)
 
-        def callback(pdu):
-            return pdu
-
-        def errback(failure, pdu):
-            failure.trap(SynapseError)
-            return None
+        @defer.inlineCallbacks
+        def handle_check_result(pdu, deferred):
+            try:
+                res = yield logcontext.make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None
 
-        def try_local_db(res, pdu):
             if not res:
                 # Check local db.
-                return self.store.get_event(
+                res = yield self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-            return res
 
-        def try_remote(res, pdu):
             if not res and pdu.origin != origin:
-                return self.get_pdu(
-                    destinations=[pdu.origin],
-                    event_id=pdu.event_id,
-                    outlier=outlier,
-                    timeout=10000,
-                ).addErrback(lambda e: None)
-            return res
-
-        def warn(res, pdu):
+                try:
+                    res = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                        timeout=10000,
+                    )
+                except SynapseError:
+                    pass
+
             if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
-            return res
 
-        for pdu, deferred in zip(pdus, deferreds):
-            deferred.addCallbacks(
-                callback, errback, errbackArgs=[pdu]
-            ).addCallback(
-                try_local_db, pdu
-            ).addCallback(
-                try_remote, pdu
-            ).addCallback(
-                warn, pdu
-            )
+            defer.returnValue(res)
 
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
-            deferreds,
-            consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        handle = logcontext.preserve_fn(handle_check_result)
+        deferreds2 = [
+            handle(pdu, deferred)
+            for pdu, deferred in zip(pdus, deferreds)
+        ]
+
+        valid_pdus = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                deferreds2,
+                consumeErrors=True,
+            )
+        ).addErrback(unwrapFirstError)
 
         if include_none:
             defer.returnValue(valid_pdus)
@@ -114,15 +113,24 @@ class FederationBase(object):
             defer.returnValue([p for p in valid_pdus if p])
 
     def _check_sigs_and_hash(self, pdu):
-        return self._check_sigs_and_hashes([pdu])[0]
+        return logcontext.make_deferred_yieldable(
+            self._check_sigs_and_hashes([pdu])[0],
+        )
 
     def _check_sigs_and_hashes(self, pdus):
-        """Throws a SynapseError if a PDU does not have the correct
-        signatures.
+        """Checks that each of the received events is correctly signed by the
+        sending server.
+
+        Args:
+            pdus (list[FrozenEvent]): the events to be checked
 
         Returns:
-            FrozenEvent: Either the given event or it redacted if it failed the
-            content hash check.
+            list[Deferred]: for each input event, a deferred which:
+              * returns the original event if the checks pass
+              * returns a redacted version of the event (if the signature
+                matched but the hash did not)
+              * throws a SynapseError if the signature check failed.
+            The deferreds run their callbacks in the sentinel logcontext.
         """
 
         redacted_pdus = [
@@ -130,26 +138,38 @@ class FederationBase(object):
             for pdu in pdus
         ]
 
-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])
 
+        ctx = logcontext.LoggingContext.current_context()
+
         def callback(_, pdu, redacted):
-            if not check_event_content_hash(pdu):
-                logger.warn(
-                    "Event content has been tampered, redacting %s: %s",
-                    pdu.event_id, pdu.get_pdu_json()
-                )
-                return redacted
-            return pdu
+            with logcontext.PreserveLoggingContext(ctx):
+                if not check_event_content_hash(pdu):
+                    logger.warn(
+                        "Event content has been tampered, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                if self.spam_checker.check_event_for_spam(pdu):
+                    logger.warn(
+                        "Event contains spam, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                return pdu
 
         def errback(failure, pdu):
             failure.trap(SynapseError)
-            logger.warn(
-                "Signature check failed for %s",
-                pdu.event_id,
-            )
+            with logcontext.PreserveLoggingContext(ctx):
+                logger.warn(
+                    "Signature check failed for %s",
+                    pdu.event_id,
+                )
             return failure
 
         for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
@@ -160,3 +180,40 @@ class FederationBase(object):
             )
 
         return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+    """Construct a FrozenEvent from an event json received over federation
+
+    Args:
+        pdu_json (object): pdu as received over federation
+        outlier (bool): True to mark this event as an outlier
+
+    Returns:
+        FrozenEvent
+
+    Raises:
+        SynapseError: if the pdu is missing required fields or is otherwise
+            not a valid matrix event
+    """
+    # we could probably enforce a bunch of other fields here (room_id, sender,
+    # origin, etc etc)
+    assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+
+    depth = pdu_json['depth']
+    if not isinstance(depth, six.integer_types):
+        raise SynapseError(400, "Depth %r not an intger" % (depth, ),
+                           Codes.BAD_JSON)
+
+    if depth < 0:
+        raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+    elif depth > MAX_DEPTH:
+        raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
+    event = FrozenEvent(
+        pdu_json
+    )
+
+    event.internal_metadata.outlier = outlier
+
+    return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b5bcfd705a..6163f7c466 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,28 +14,30 @@
 # limitations under the License.
 
 
+import copy
+import itertools
+import logging
+import random
+
+from six.moves import range
+
 from twisted.internet import defer
 
-from .federation_base import FederationBase
 from synapse.api.constants import Membership
-
 from synapse.api.errors import (
-    CodeMessageException, HttpResponseException, SynapseError,
+    CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
 )
-from synapse.util import unwrapFirstError
+from synapse.events import builder
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
-
-import copy
-import itertools
-import logging
-import random
-
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -58,6 +60,7 @@ class FederationClient(FederationBase):
             self._clear_tried_cache, 60 * 1000,
         )
         self.state = hs.get_state_handler()
+        self.transport_layer = hs.get_federation_transport_client()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -88,7 +91,7 @@ class FederationClient(FederationBase):
 
     @log_function
     def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=False):
+                   retry_on_dns_fail=False, ignore_backoff=False):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -98,6 +101,8 @@ class FederationClient(FederationBase):
                 handler name used in register_query_handler().
             args (dict): Mapping of strings to strings containing the details
                 of the query request.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
 
         Returns:
             a Deferred which will eventually yield a JSON object from the
@@ -106,7 +111,8 @@ class FederationClient(FederationBase):
         sent_queries_counter.inc(query_type)
 
         return self.transport_layer.make_query(
-            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
     @log_function
@@ -181,15 +187,15 @@ class FederationClient(FederationBase):
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=False)
+            event_from_pdu_json(p, outlier=False)
             for p in transaction_data["pdus"]
         ]
 
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(pdus)
 
@@ -206,8 +212,7 @@ class FederationClient(FederationBase):
 
         Args:
             destinations (list): Which home servers to query
-            pdu_origin (str): The home server that originally sent the pdu.
-            event_id (str)
+            event_id (str): event to fetch
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
@@ -235,31 +240,24 @@ class FederationClient(FederationBase):
                 continue
 
             try:
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self._clock,
-                    self.store,
+                transaction_data = yield self.transport_layer.get_event(
+                    destination, event_id, timeout=timeout,
                 )
 
-                with limiter:
-                    transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id, timeout=timeout,
-                    )
-
-                    logger.debug("transaction_data %r", transaction_data)
+                logger.debug("transaction_data %r", transaction_data)
 
-                    pdu_list = [
-                        self.event_from_pdu_json(p, outlier=outlier)
-                        for p in transaction_data["pdus"]
-                    ]
+                pdu_list = [
+                    event_from_pdu_json(p, outlier=outlier)
+                    for p in transaction_data["pdus"]
+                ]
 
-                    if pdu_list and pdu_list[0]:
-                        pdu = pdu_list[0]
+                if pdu_list and pdu_list[0]:
+                    pdu = pdu_list[0]
 
-                        # Check signatures are correct.
-                        signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    # Check signatures are correct.
+                    signed_pdu = yield self._check_sigs_and_hash(pdu)
 
-                        break
+                    break
 
                 pdu_attempts[destination] = now
 
@@ -271,6 +269,9 @@ class FederationClient(FederationBase):
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
+            except FederationDeniedError as e:
+                logger.info(e.message)
+                continue
             except Exception as e:
                 pdu_attempts[destination] = now
 
@@ -341,11 +342,11 @@ class FederationClient(FederationBase):
         )
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
         ]
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in result.get("auth_chain", [])
         ]
 
@@ -395,7 +396,7 @@ class FederationClient(FederationBase):
             seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
             signed_events = seen_events.values()
         else:
-            seen_events = yield self.store.have_events(event_ids)
+            seen_events = yield self.store.have_seen_events(event_ids)
             signed_events = []
 
         failed_to_fetch = set()
@@ -414,18 +415,19 @@ class FederationClient(FederationBase):
 
         batch_size = 20
         missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
+        for i in range(0, len(missing_events), batch_size):
             batch = set(missing_events[i:i + batch_size])
 
             deferreds = [
-                preserve_fn(self.get_pdu)(
+                run_in_background(
+                    self.get_pdu,
                     destinations=random_server_list(),
                     event_id=e_id,
                 )
                 for e_id in batch
             ]
 
-            res = yield preserve_context_over_deferred(
+            res = yield make_deferred_yieldable(
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
@@ -446,7 +448,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in res["auth_chain"]
         ]
 
@@ -479,8 +481,13 @@ class FederationClient(FederationBase):
             content (object): Any additional data to put into the content field
                 of the event.
         Return:
-            A tuple of (origin (str), event (object)) where origin is the remote
-            homeserver which generated the event.
+            Deferred: resolves to a tuple of (origin (str), event (object))
+            where origin is the remote homeserver which generated the event.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a 300/400 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
         """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
@@ -533,6 +540,27 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
+        """Sends a join event to one of a list of homeservers.
+
+        Doing so will cause the remote server to add the event to the graph,
+        and send the event out to the rest of the federation.
+
+        Args:
+            destinations (str): Candidate homeservers which are probably
+                participating in the room.
+            pdu (BaseEvent): event to be sent
+
+        Return:
+            Deferred: resolves to a dict with members ``origin`` (a string
+            giving the serer the event was sent to, ``state`` (?) and
+            ``auth_chain``.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a 300/400 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
+        """
+
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -549,12 +577,12 @@ class FederationClient(FederationBase):
                 logger.debug("Got content: %s", content)
 
                 state = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("state", [])
                 ]
 
                 auth_chain = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("auth_chain", [])
                 ]
 
@@ -629,7 +657,7 @@ class FederationClient(FederationBase):
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        pdu = self.event_from_pdu_json(pdu_dict)
+        pdu = event_from_pdu_json(pdu_dict)
 
         # Check signatures are correct.
         pdu = yield self._check_sigs_and_hash(pdu)
@@ -640,6 +668,26 @@ class FederationClient(FederationBase):
 
     @defer.inlineCallbacks
     def send_leave(self, destinations, pdu):
+        """Sends a leave event to one of a list of homeservers.
+
+        Doing so will cause the remote server to add the event to the graph,
+        and send the event out to the rest of the federation.
+
+        This is mostly useful to reject received invites.
+
+        Args:
+            destinations (str): Candidate homeservers which are probably
+                participating in the room.
+            pdu (BaseEvent): event to be sent
+
+        Return:
+            Deferred: resolves to None.
+
+            Fails with a ``CodeMessageException`` if the chosen remote server
+            returns a non-200 code.
+
+            Fails with a ``RuntimeError`` if no servers were reachable.
+        """
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -699,7 +747,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(e)
+            event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
@@ -747,7 +795,7 @@ class FederationClient(FederationBase):
             )
 
             events = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content.get("events", [])
             ]
 
@@ -764,15 +812,6 @@ class FederationClient(FederationBase):
 
         defer.returnValue(signed_events)
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def forward_third_party_invite(self, destinations, room_id, event_dict):
         for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e922b7ff4a..247ddc89d5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,27 +13,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
-
+import simplejson as json
 from twisted.internet import defer
 
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
 
-from synapse.util.async import Linearizer
-from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
 import synapse.metrics
+from synapse.types import get_domain_from_id
+from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.logutils import log_function
 
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
+from six import iteritems
 
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
 
 logger = logging.getLogger(__name__)
 
@@ -51,49 +56,18 @@ class FederationServer(FederationBase):
         super(FederationServer, self).__init__(hs)
 
         self.auth = hs.get_auth()
+        self.handler = hs.get_handlers().federation_handler
 
-        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
-        self._server_linearizer = Linearizer("fed_server")
-
-        # We cache responses to state queries, as they take a while and often
-        # come in waves.
-        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
-
-    def set_handler(self, handler):
-        """Sets the handler that the replication layer will use to communicate
-        receipt of new PDUs from other home servers. The required methods are
-        documented on :py:class:`.ReplicationHandler`.
-        """
-        self.handler = handler
-
-    def register_edu_handler(self, edu_type, handler):
-        if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
-        self.edu_handlers[edu_type] = handler
-
-    def register_query_handler(self, query_type, handler):
-        """Sets the handler callable that will be used to handle an incoming
-        federation Query of the given type.
-
-        Args:
-            query_type (str): Category name of the query, which should match
-                the string used by make_query.
-            handler (callable): Invoked to handle incoming queries of this type
+        self._server_linearizer = async.Linearizer("fed_server")
+        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
-        handler is invoked as:
-            result = handler(args)
+        self.transaction_actions = TransactionActions(self.store)
 
-        where 'args' is a dict mapping strings to strings of the query
-          arguments. It should return a Deferred that will eventually yield an
-          object to encode as JSON.
-        """
-        if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
+        self.registry = hs.get_federation_registry()
 
-        self.query_handlers[query_type] = handler
+        # We cache responses to state queries, as they take a while and often
+        # come in waves.
+        self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
 
     @defer.inlineCallbacks
     @log_function
@@ -110,25 +84,41 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_incoming_transaction(self, transaction_data):
+        # keep this as early as possible to make the calculated origin ts as
+        # accurate as possible.
+        request_time = self._clock.time_msec()
+
         transaction = Transaction(**transaction_data)
 
-        received_pdus_counter.inc_by(len(transaction.pdus))
+        if not transaction.transaction_id:
+            raise Exception("Transaction missing transaction_id")
+        if not transaction.origin:
+            raise Exception("Transaction missing origin")
 
-        for p in transaction.pdus:
-            if "unsigned" in p:
-                unsigned = p["unsigned"]
-                if "age" in unsigned:
-                    p["age"] = unsigned["age"]
-            if "age" in p:
-                p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
-                del p["age"]
+        logger.debug("[%s] Got transaction", transaction.transaction_id)
 
-        pdu_list = [
-            self.event_from_pdu_json(p) for p in transaction.pdus
-        ]
+        # use a linearizer to ensure that we don't process the same transaction
+        # multiple times in parallel.
+        with (yield self._transaction_linearizer.queue(
+                (transaction.origin, transaction.transaction_id),
+        )):
+            result = yield self._handle_incoming_transaction(
+                transaction, request_time,
+            )
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _handle_incoming_transaction(self, transaction, request_time):
+        """ Process an incoming transaction and return the HTTP response
 
+        Args:
+            transaction (Transaction): incoming transaction
+            request_time (int): timestamp that the HTTP request arrived at
+
+        Returns:
+            Deferred[(int, object)]: http response code and body
+        """
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
@@ -141,38 +131,49 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        results = []
-
-        for pdu in pdu_list:
-            # check that it's actually being sent from a valid destination to
-            # workaround bug #1753 in 0.18.5 and 0.18.6
-            if transaction.origin != get_domain_from_id(pdu.event_id):
-                if not (
-                    pdu.type == 'm.room.member' and
-                    pdu.content and
-                    pdu.content.get("membership", None) == 'join' and
-                    self.hs.is_mine_id(pdu.state_key)
-                ):
-                    logger.info(
-                        "Discarding PDU %s from invalid origin %s",
-                        pdu.event_id, transaction.origin
-                    )
-                    continue
-                else:
-                    logger.info(
-                        "Accepting join PDU %s from %s",
-                        pdu.event_id, transaction.origin
-                    )
+        received_pdus_counter.inc_by(len(transaction.pdus))
+
+        pdus_by_room = {}
+
+        for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
+            if "age" in p:
+                p["age_ts"] = request_time - int(p["age"])
+                del p["age"]
 
-            try:
-                yield self._handle_new_pdu(transaction.origin, pdu)
-                results.append({})
-            except FederationError as e:
-                self.send_failure(e, transaction.origin)
-                results.append({"error": str(e)})
-            except Exception as e:
-                results.append({"error": str(e)})
-                logger.exception("Failed to handle PDU")
+            event = event_from_pdu_json(p)
+            room_id = event.room_id
+            pdus_by_room.setdefault(room_id, []).append(event)
+
+        pdu_results = {}
+
+        # we can process different rooms in parallel (which is useful if they
+        # require callouts to other servers to fetch missing events), but
+        # impose a limit to avoid going too crazy with ram/cpu.
+        @defer.inlineCallbacks
+        def process_pdus_for_room(room_id):
+            logger.debug("Processing PDUs for %s", room_id)
+            for pdu in pdus_by_room[room_id]:
+                event_id = pdu.event_id
+                try:
+                    yield self._handle_received_pdu(
+                        transaction.origin, pdu
+                    )
+                    pdu_results[event_id] = {}
+                except FederationError as e:
+                    logger.warn("Error handling PDU %s: %s", event_id, e)
+                    pdu_results[event_id] = {"error": str(e)}
+                except Exception as e:
+                    pdu_results[event_id] = {"error": str(e)}
+                    logger.exception("Failed to handle PDU %s", event_id)
+
+        yield async.concurrently_execute(
+            process_pdus_for_room, pdus_by_room.keys(),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
@@ -182,17 +183,16 @@ class FederationServer(FederationBase):
                     edu.content
                 )
 
-            for failure in getattr(transaction, "pdu_failures", []):
-                logger.info("Got failure %r", failure)
-
-        logger.debug("Returning: %s", str(results))
+        pdu_failures = getattr(transaction, "pdu_failures", [])
+        for failure in pdu_failures:
+            logger.info("Got failure %r", failure)
 
         response = {
-            "pdus": dict(zip(
-                (p.event_id for p in pdu_list), results
-            )),
+            "pdus": pdu_results,
         }
 
+        logger.debug("Returning: %s", str(response))
+
         yield self.transaction_actions.set_response(
             transaction,
             200, response
@@ -202,16 +202,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
         received_edus_counter.inc()
-
-        if edu_type in self.edu_handlers:
-            try:
-                yield self.edu_handlers[edu_type](origin, content)
-            except SynapseError as e:
-                logger.info("Failed to handle edu %r: %r", edu_type, e)
-            except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type)
-        else:
-            logger.warn("Received EDU of type %s with no handler", edu_type)
+        yield self.registry.on_edu(edu_type, origin, content)
 
     @defer.inlineCallbacks
     @log_function
@@ -223,15 +214,17 @@ class FederationServer(FederationBase):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        result = self._state_resp_cache.get((room_id, event_id))
-        if not result:
-            with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
-                    (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
-                )
-        else:
-            resp = yield result
+        # we grab the linearizer to protect ourselves from servers which hammer
+        # us. In theory we might already have the response to this query
+        # in the cache so we could return it without waiting for the linearizer
+        # - but that's non-trivial to get right, and anyway somewhat defeats
+        # the point of the linearizer.
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            resp = yield self._state_resp_cache.wrap(
+                (room_id, event_id),
+                self._on_context_state_request_compute,
+                room_id, event_id,
+            )
 
         defer.returnValue((200, resp))
 
@@ -300,14 +293,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.inc(query_type)
-
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue(
-                (404, "No handler for Query type '%s'" % (query_type,))
-            )
+        resp = yield self.registry.on_query(query_type, args)
+        defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
     def on_make_join_request(self, room_id, user_id):
@@ -317,7 +304,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -325,7 +312,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -345,7 +332,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -382,7 +369,7 @@ class FederationServer(FederationBase):
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
             auth_chain = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content["auth_chain"]
             ]
 
@@ -437,6 +424,16 @@ class FederationServer(FederationBase):
                         key_id: json.loads(json_bytes)
                     }
 
+        logger.info(
+            "Claimed one-time-keys: %s",
+            ",".join((
+                "%s for %s:%s" % (key_id, user_id, device_id)
+                for user_id, user_keys in iteritems(json_result)
+                for device_id, device_keys in iteritems(user_keys)
+                for key_id, _ in iteritems(device_keys)
+            )),
+        )
+
         defer.returnValue({"one_time_keys": json_result})
 
     @defer.inlineCallbacks
@@ -497,26 +494,59 @@ class FederationServer(FederationBase):
         )
 
     @defer.inlineCallbacks
-    @log_function
-    def _handle_new_pdu(self, origin, pdu, get_missing=True):
+    def _handle_received_pdu(self, origin, pdu):
+        """ Process a PDU received in a federation /send/ transaction.
+
+        If the event is invalid, then this method throws a FederationError.
+        (The error will then be logged and sent back to the sender (which
+        probably won't do anything with it), and other events in the
+        transaction will be processed as normal).
+
+        It is likely that we'll then receive other events which refer to
+        this rejected_event in their prev_events, etc.  When that happens,
+        we'll attempt to fetch the rejected event again, which will presumably
+        fail, so those second-generation events will also get rejected.
+
+        Eventually, we get to the point where there are more than 10 events
+        between any new events and the original rejected event. Since we
+        only try to backfill 10 events deep on received pdu, we then accept the
+        new event, possibly introducing a discontinuity in the DAG, with new
+        forward extremities, so normal service is approximately returned,
+        until we try to backfill across the discontinuity.
 
-        # We reprocess pdus when we have seen them only as outliers
-        existing = yield self._get_persisted_pdu(
-            origin, pdu.event_id, do_auth=False
-        )
+        Args:
+            origin (str): server which sent the pdu
+            pdu (FrozenEvent): received pdu
 
-        # FIXME: Currently we fetch an event again when we already have it
-        # if it has been marked as an outlier.
+        Returns (Deferred): completes with None
 
-        already_seen = (
-            existing and (
-                not existing.internal_metadata.is_outlier()
-                or pdu.internal_metadata.is_outlier()
-            )
-        )
-        if already_seen:
-            logger.debug("Already seen pdu %s", pdu.event_id)
-            return
+        Raises: FederationError if the signatures / hash do not match, or
+            if the event was unacceptable for any other reason (eg, too large,
+            too many prev_events, couldn't find the prev_events)
+        """
+        # check that it's actually being sent from a valid destination to
+        # workaround bug #1753 in 0.18.5 and 0.18.6
+        if origin != get_domain_from_id(pdu.event_id):
+            # We continue to accept join events from any server; this is
+            # necessary for the federation join dance to work correctly.
+            # (When we join over federation, the "helper" server is
+            # responsible for sending out the join event, rather than the
+            # origin. See bug #1893).
+            if not (
+                pdu.type == 'm.room.member' and
+                pdu.content and
+                pdu.content.get("membership", None) == 'join'
+            ):
+                logger.info(
+                    "Discarding PDU %s from invalid origin %s",
+                    pdu.event_id, origin
+                )
+                return
+            else:
+                logger.info(
+                    "Accepting join PDU %s from %s",
+                    pdu.event_id, origin
+                )
 
         # Check signature.
         try:
@@ -529,156 +559,11 @@ class FederationServer(FederationBase):
                 affected=pdu.event_id,
             )
 
-        state = None
-
-        auth_chain = []
-
-        have_seen = yield self.store.have_events(
-            [ev for ev, _ in pdu.prev_events]
-        )
-
-        fetch_state = False
-
-        # Get missing pdus if necessary.
-        if not pdu.internal_metadata.is_outlier():
-            # We only backfill backwards to the min depth.
-            min_depth = yield self.handler.get_min_depth_for_context(
-                pdu.room_id
-            )
-
-            logger.debug(
-                "_handle_new_pdu min_depth for %s: %d",
-                pdu.room_id, min_depth
-            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-
-            if min_depth and pdu.depth < min_depth:
-                # This is so that we don't notify the user about this
-                # message, to work around the fact that some events will
-                # reference really really old events we really don't want to
-                # send to the clients.
-                pdu.internal_metadata.outlier = True
-            elif min_depth and pdu.depth > min_depth:
-                if get_missing and prevs - seen:
-                    # If we're missing stuff, ensure we only fetch stuff one
-                    # at a time.
-                    logger.info(
-                        "Acquiring lock for room %r to fetch %d missing events: %r...",
-                        pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
-                    )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
-                        logger.info(
-                            "Acquired lock for room %r to fetch %d missing events",
-                            pdu.room_id, len(prevs - seen),
-                        )
-
-                        # We recalculate seen, since it may have changed.
-                        have_seen = yield self.store.have_events(prevs)
-                        seen = set(have_seen.keys())
-
-                        if prevs - seen:
-                            latest = yield self.store.get_latest_event_ids_in_room(
-                                pdu.room_id
-                            )
-
-                            # We add the prev events that we have seen to the latest
-                            # list to ensure the remote server doesn't give them to us
-                            latest = set(latest)
-                            latest |= seen
-
-                            logger.info(
-                                "Missing %d events for room %r: %r...",
-                                len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                            )
-
-                            # XXX: we set timeout to 10s to help workaround
-                            # https://github.com/matrix-org/synapse/issues/1733.
-                            # The reason is to avoid holding the linearizer lock
-                            # whilst processing inbound /send transactions, causing
-                            # FDs to stack up and block other inbound transactions
-                            # which empirically can currently take up to 30 minutes.
-                            #
-                            # N.B. this explicitly disables retry attempts.
-                            #
-                            # N.B. this also increases our chances of falling back to
-                            # fetching fresh state for the room if the missing event
-                            # can't be found, which slightly reduces our security.
-                            # it may also increase our DAG extremity count for the room,
-                            # causing additional state resolution?  See #1760.
-                            # However, fetching state doesn't hold the linearizer lock
-                            # apparently.
-                            #
-                            # see https://github.com/matrix-org/synapse/pull/1744
-
-                            missing_events = yield self.get_missing_events(
-                                origin,
-                                pdu.room_id,
-                                earliest_events_ids=list(latest),
-                                latest_events=[pdu],
-                                limit=10,
-                                min_depth=min_depth,
-                                timeout=10000,
-                            )
-
-                            # We want to sort these by depth so we process them and
-                            # tell clients about them in order.
-                            missing_events.sort(key=lambda x: x.depth)
-
-                            for e in missing_events:
-                                yield self._handle_new_pdu(
-                                    origin,
-                                    e,
-                                    get_missing=False
-                                )
-
-                            have_seen = yield self.store.have_events(
-                                [ev for ev, _ in pdu.prev_events]
-                            )
-
-            prevs = {e_id for e_id, _ in pdu.prev_events}
-            seen = set(have_seen.keys())
-            if prevs - seen:
-                logger.info(
-                    "Still missing %d events for room %r: %r...",
-                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                )
-                fetch_state = True
-
-        if fetch_state:
-            # We need to get the state at this event, since we haven't
-            # processed all the prev events.
-            logger.debug(
-                "_handle_new_pdu getting state for %s",
-                pdu.room_id
-            )
-            try:
-                state, auth_chain = yield self.get_state_for_room(
-                    origin, pdu.room_id, pdu.event_id,
-                )
-            except:
-                logger.exception("Failed to get state for event: %s", pdu.event_id)
-
-        yield self.handler.on_receive_pdu(
-            origin,
-            pdu,
-            state=state,
-            auth_chain=auth_chain,
-        )
+        yield self.handler.on_receive_pdu(origin, pdu, get_missing=True)
 
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def exchange_third_party_invite(
             self,
@@ -701,3 +586,66 @@ class FederationServer(FederationBase):
             origin, room_id, event_dict
         )
         defer.returnValue(ret)
+
+
+class FederationHandlerRegistry(object):
+    """Allows classes to register themselves as handlers for a given EDU or
+    query type for incoming federation traffic.
+    """
+    def __init__(self):
+        self.edu_handlers = {}
+        self.query_handlers = {}
+
+    def register_edu_handler(self, edu_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation EDU of the given type.
+
+        Args:
+            edu_type (str): The type of the incoming EDU to register handler for
+            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+                of the given type. The arguments are the origin server name and
+                the EDU contents.
+        """
+        if edu_type in self.edu_handlers:
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+        self.edu_handlers[edu_type] = handler
+
+    def register_query_handler(self, query_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation query of the given type.
+
+        Args:
+            query_type (str): Category name of the query, which should match
+                the string used by make_query.
+            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+                incoming queries of this type. The return will be yielded
+                on and the result used as the response to the query request.
+        """
+        if query_type in self.query_handlers:
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
+
+        self.query_handlers[query_type] = handler
+
+    @defer.inlineCallbacks
+    def on_edu(self, edu_type, origin, content):
+        handler = self.edu_handlers.get(edu_type)
+        if not handler:
+            logger.warn("No handler registered for EDU type %s", edu_type)
+
+        try:
+            yield handler(origin, content)
+        except SynapseError as e:
+            logger.info("Failed to handle edu %r: %r", edu_type, e)
+        except Exception as e:
+            logger.exception("Failed to handle edu %r", edu_type)
+
+    def on_query(self, query_type, args):
+        handler = self.query_handlers.get(query_type)
+        if not handler:
+            logger.warn("No handler registered for query type %s", query_type)
+            raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+        return handler(args)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
-    """This layer is responsible for replicating with remote home servers over
-    the given transport. I.e., does the sending and receiving of PDUs to
-    remote home servers.
-
-    The layer communicates with the rest of the server via a registered
-    ReplicationHandler.
-
-    In more detail, the layer:
-        * Receives incoming data and processes it into transactions and pdus.
-        * Fetches any PDUs it thinks it might have missed.
-        * Keeps the current state for contexts up to date by applying the
-          suitable conflict resolution.
-        * Sends outgoing pdus wrapped in transactions.
-        * Fills out the references to previous pdus/transactions appropriately
-          for outgoing data.
-    """
-
-    def __init__(self, hs, transport_layer):
-        self.server_name = hs.hostname
-
-        self.keyring = hs.get_keyring()
-
-        self.transport_layer = transport_layer
-
-        self.federation_client = self
-
-        self.store = hs.get_datastore()
-
-        self.handler = None
-        self.edu_handlers = {}
-        self.query_handlers = {}
-
-        self._clock = hs.get_clock()
-
-        self.transaction_actions = TransactionActions(self.store)
-
-        self.hs = hs
-
-        super(ReplicationLayer, self).__init__(hs)
-
-    def __str__(self):
-        return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86f0..0f0c687b37 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,21 +31,21 @@ Events are replicated via a separate events stream.
 
 from .units import Edu
 
+from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 import synapse.metrics
 
 from blist import sorteddict
-import ujson
+from collections import namedtuple
 
+import logging
 
-metrics = synapse.metrics.get_metrics_for(__name__)
+from six import itervalues, iteritems
+
+logger = logging.getLogger(__name__)
 
 
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
+metrics = synapse.metrics.get_metrics_for(__name__)
 
 
 class FederationRemoteSendQueue(object):
@@ -54,18 +54,20 @@ class FederationRemoteSendQueue(object):
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}
-        self.presence_changed = sorteddict()
+        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
+        self.presence_changed = sorteddict()  # Stream position -> user_id
 
-        self.keyed_edu = {}
-        self.keyed_edu_changed = sorteddict()
+        self.keyed_edu = {}  # (destination, key) -> EDU
+        self.keyed_edu_changed = sorteddict()  # stream position -> (destination, key)
 
-        self.edus = sorteddict()
+        self.edus = sorteddict()  # stream position -> Edu
 
-        self.failures = sorteddict()
+        self.failures = sorteddict()  # stream position -> (destination, Failure)
 
-        self.device_messages = sorteddict()
+        self.device_messages = sorteddict()  # stream position -> destination
 
         self.pos = 1
         self.pos_time = sorteddict()
@@ -121,7 +123,9 @@ class FederationRemoteSendQueue(object):
                 del self.presence_changed[key]
 
             user_ids = set(
-                user_id for uids in self.presence_changed.values() for _, user_id in uids
+                user_id
+                for uids in itervalues(self.presence_changed)
+                for user_id in uids
             )
 
             to_del = [
@@ -186,37 +190,50 @@ class FederationRemoteSendQueue(object):
         else:
             self.edus[pos] = edu
 
-    def send_presence(self, destination, states):
-        """As per TransactionQueue"""
+        self.notifier.on_new_replication_data()
+
+    def send_presence(self, states):
+        """As per TransactionQueue
+
+        Args:
+            states (list(UserPresenceState))
+        """
         pos = self._next_pos()
 
-        self.presence_map.update({
-            state.user_id: state
-            for state in states
-        })
+        # We only want to send presence for our own users, so lets always just
+        # filter here just in case.
+        local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
 
-        self.presence_changed[pos] = [
-            (destination, state.user_id) for state in states
-        ]
+        self.presence_map.update({state.user_id: state for state in local_states})
+        self.presence_changed[pos] = [state.user_id for state in local_states]
+
+        self.notifier.on_new_replication_data()
 
     def send_failure(self, failure, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
 
         self.failures[pos] = (destination, str(failure))
+        self.notifier.on_new_replication_data()
 
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
         self.device_messages[pos] = destination
+        self.notifier.on_new_replication_data()
 
     def get_current_token(self):
         return self.pos - 1
 
-    def get_replication_rows(self, token, limit, federation_ack=None):
-        """
+    def federation_ack(self, token):
+        self._clear_queue_before_pos(token)
+
+    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+        """Get rows to be sent over federation between the two tokens
+
         Args:
-            token (int)
+            from_token (int)
+            to_token(int)
             limit (int)
             federation_ack (int): Optional. The position where the worker is
                 explicitly acknowledged it has handled. Allows us to drop
@@ -225,9 +242,11 @@ class FederationRemoteSendQueue(object):
         # TODO: Handle limit.
 
         # To handle restarts where we wrap around
-        if token > self.pos:
-            token = -1
+        if from_token > self.pos:
+            from_token = -1
 
+        # list of tuple(int, BaseFederationRow), where the first is the position
+        # of the federation stream.
         rows = []
 
         # There should be only one reader, so lets delete everything its
@@ -237,62 +256,295 @@ class FederationRemoteSendQueue(object):
 
         # Fetch changed presence
         keys = self.presence_changed.keys()
-        i = keys.bisect_right(token)
-        dest_user_ids = set(
-            (pos, dest_user_id)
-            for pos in keys[i:]
-            for dest_user_id in self.presence_changed[pos]
-        )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        dest_user_ids = [
+            (pos, user_id)
+            for pos in keys[i:j]
+            for user_id in self.presence_changed[pos]
+        ]
 
-        for (key, (dest, user_id)) in dest_user_ids:
-            rows.append((key, PRESENCE_TYPE, ujson.dumps({
-                "destination": dest,
-                "state": self.presence_map[user_id].as_dict(),
-            })))
+        for (key, user_id) in dest_user_ids:
+            rows.append((key, PresenceRow(
+                state=self.presence_map[user_id],
+            )))
 
         # Fetch changes keyed edus
         keys = self.keyed_edu_changed.keys()
-        i = keys.bisect_right(token)
-        keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
-
-        for (pos, (destination, edu_key)) in keyed_edus:
-            rows.append(
-                (pos, KEYED_EDU_TYPE, ujson.dumps({
-                    "key": edu_key,
-                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
-                }))
-            )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        # We purposefully clobber based on the key here, python dict comprehensions
+        # always use the last value, so this will correctly point to the last
+        # stream position.
+        keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+
+        for ((destination, edu_key), pos) in iteritems(keyed_edus):
+            rows.append((pos, KeyedEduRow(
+                key=edu_key,
+                edu=self.keyed_edu[(destination, edu_key)],
+            )))
 
         # Fetch changed edus
         keys = self.edus.keys()
-        i = keys.bisect_right(token)
-        edus = set((k, self.edus[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        edus = ((k, self.edus[k]) for k in keys[i:j])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+            rows.append((pos, EduRow(edu)))
 
         # Fetch changed failures
         keys = self.failures.keys()
-        i = keys.bisect_right(token)
-        failures = set((k, self.failures[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        failures = ((k, self.failures[k]) for k in keys[i:j])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, FAILURE_TYPE, ujson.dumps({
-                "destination": destination,
-                "failure": failure,
-            })))
+            rows.append((pos, FailureRow(
+                destination=destination,
+                failure=failure,
+            )))
 
         # Fetch changed device messages
         keys = self.device_messages.keys()
-        i = keys.bisect_right(token)
-        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        device_messages = {self.device_messages[k]: k for k in keys[i:j]}
 
-        for (pos, destination) in device_messages:
-            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
-                "destination": destination,
-            })))
+        for (destination, pos) in iteritems(device_messages):
+            rows.append((pos, DeviceRow(
+                destination=destination,
+            )))
 
         # Sort rows based on pos
         rows.sort()
 
-        return rows
+        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+    """Base class for rows to be sent in the federation stream.
+
+    Specifies how to identify, serialize and deserialize the different types.
+    """
+
+    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+
+    @staticmethod
+    def from_data(data):
+        """Parse the data from the federation stream into a row.
+
+        Args:
+            data: The value of ``data`` from FederationStreamRow.data, type
+                depends on the type of stream
+        """
+        raise NotImplementedError()
+
+    def to_data(self):
+        """Serialize this row to be sent over the federation stream.
+
+        Returns:
+            The value to be sent in FederationStreamRow.data. The type depends
+            on the type of stream.
+        """
+        raise NotImplementedError()
+
+    def add_to_buffer(self, buff):
+        """Add this row to the appropriate field in the buffer ready for this
+        to be sent over federation.
+
+        We use a buffer so that we can batch up events that have come in at
+        the same time and send them all at once.
+
+        Args:
+            buff (BufferedToSend)
+        """
+        raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+    "state",  # UserPresenceState
+))):
+    TypeId = "p"
+
+    @staticmethod
+    def from_data(data):
+        return PresenceRow(
+            state=UserPresenceState.from_dict(data)
+        )
+
+    def to_data(self):
+        return self.state.as_dict()
+
+    def add_to_buffer(self, buff):
+        buff.presence.append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+    "key",  # tuple(str) - the edu key passed to send_edu
+    "edu",  # Edu
+))):
+    """Streams EDUs that have an associated key that is ued to clobber. For example,
+    typing EDUs clobber based on room_id.
+    """
+
+    TypeId = "k"
+
+    @staticmethod
+    def from_data(data):
+        return KeyedEduRow(
+            key=tuple(data["key"]),
+            edu=Edu(**data["edu"]),
+        )
+
+    def to_data(self):
+        return {
+            "key": self.key,
+            "edu": self.edu.get_internal_dict(),
+        }
+
+    def add_to_buffer(self, buff):
+        buff.keyed_edus.setdefault(
+            self.edu.destination, {}
+        )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+    "edu",  # Edu
+))):
+    """Streams EDUs that don't have keys. See KeyedEduRow
+    """
+    TypeId = "e"
+
+    @staticmethod
+    def from_data(data):
+        return EduRow(Edu(**data))
+
+    def to_data(self):
+        return self.edu.get_internal_dict()
+
+    def add_to_buffer(self, buff):
+        buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+    "destination",  # str
+    "failure",
+))):
+    """Streams failures to a remote server. Failures are issued when there was
+    something wrong with a transaction the remote sent us, e.g. it included
+    an event that was invalid.
+    """
+
+    TypeId = "f"
+
+    @staticmethod
+    def from_data(data):
+        return FailureRow(
+            destination=data["destination"],
+            failure=data["failure"],
+        )
+
+    def to_data(self):
+        return {
+            "destination": self.destination,
+            "failure": self.failure,
+        }
+
+    def add_to_buffer(self, buff):
+        buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+    "destination",  # str
+))):
+    """Streams the fact that either a) there is pending to device messages for
+    users on the remote, or b) a local users device has changed and needs to
+    be sent to the remote.
+    """
+    TypeId = "d"
+
+    @staticmethod
+    def from_data(data):
+        return DeviceRow(destination=data["destination"])
+
+    def to_data(self):
+        return {"destination": self.destination}
+
+    def add_to_buffer(self, buff):
+        buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+    Row.TypeId: Row
+    for Row in (
+        PresenceRow,
+        KeyedEduRow,
+        EduRow,
+        FailureRow,
+        DeviceRow,
+    )
+}
+
+
+ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
+    "presence",  # list(UserPresenceState)
+    "keyed_edus",  # dict of destination -> { key -> Edu }
+    "edus",  # dict of destination -> [Edu]
+    "failures",  # dict of destination -> [failures]
+    "device_destinations",  # set of destinations
+))
+
+
+def process_rows_for_federation(transaction_queue, rows):
+    """Parse a list of rows from the federation stream and put them in the
+    transaction queue ready for sending to the relevant homeservers.
+
+    Args:
+        transaction_queue (TransactionQueue)
+        rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+    """
+
+    # The federation stream contains a bunch of different types of
+    # rows that need to be handled differently. We parse the rows, put
+    # them into the appropriate collection and then send them off.
+
+    buff = ParsedFederationStreamData(
+        presence=[],
+        keyed_edus={},
+        edus={},
+        failures={},
+        device_destinations=set(),
+    )
+
+    # Parse the rows in the stream and add to the buffer
+    for row in rows:
+        if row.type not in TypeToRow:
+            logger.error("Unrecognized federation row type %r", row.type)
+            continue
+
+        RowType = TypeToRow[row.type]
+        parsed_row = RowType.from_data(row.data)
+        parsed_row.add_to_buffer(buff)
+
+    if buff.presence:
+        transaction_queue.send_presence(buff.presence)
+
+    for destination, edu_map in iteritems(buff.keyed_edus):
+        for key, edu in edu_map.items():
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=key,
+            )
+
+    for destination, edu_list in iteritems(buff.edus):
+        for edu in edu_list:
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=None,
+            )
+
+    for destination, failure_list in iteritems(buff.failures):
+        for failure in failure_list:
+            transaction_queue.send_failure(destination, failure)
+
+    for destination in buff.device_destinations:
+        transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index bb3d9258a6..ded2b1871a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -12,22 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import datetime
 
 from twisted.internet import defer
 
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
+from synapse.util import logcontext, PreserveLoggingContext
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn
-from synapse.util.retryutils import (
-    get_retry_limiter, NotRetryingDestination,
-)
+from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
-from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 import synapse.metrics
 
 import logging
@@ -43,6 +40,10 @@ sent_pdus_destination_dist = client_metrics.register_distribution(
 )
 sent_edus_counter = client_metrics.register_counter("sent_edus")
 
+sent_transactions_counter = client_metrics.register_counter("sent_transactions")
+
+events_processed_counter = client_metrics.register_counter("events_processed")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -79,8 +80,18 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
 
-        # Presence needs to be separate as we send single aggragate EDUs
+        # Map of user_id -> UserPresenceState for all the pending presence
+        # to be sent out by user_id. Entries here get processed and put in
+        # pending_presence_by_dest
+        self.pending_presence = {}
+
+        # Map of destination -> user_id -> UserPresenceState of pending presence
+        # to be sent to each destinations
         self.pending_presence_by_dest = presence = {}
+
+        # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
+        # based on their key (e.g. typing events by room_id)
+        # Map of destination -> (edu_type, key) -> Edu
         self.pending_edus_keyed_by_dest = edus_keyed = {}
 
         metrics.register_callback(
@@ -99,7 +110,12 @@ class TransactionQueue(object):
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
 
+        # destination -> stream_id of last successfully sent to-device message.
+        # NB: may be a long or an int.
         self.last_device_stream_id_by_dest = {}
+
+        # destination -> stream_id of last successfully sent device list
+        # update.
         self.last_device_list_stream_id_by_dest = {}
 
         # HACK to get unique tx id
@@ -110,6 +126,8 @@ class TransactionQueue(object):
         self._is_processing = False
         self._last_poked_id = -1
 
+        self._processing_pending_presence = False
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -130,7 +148,6 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    @defer.inlineCallbacks
     def notify_new_events(self, current_id):
         """This gets called when we have some new events we might want to
         send out to other servers.
@@ -140,12 +157,19 @@ class TransactionQueue(object):
         if self._is_processing:
             return
 
+        # fire off a processing loop in the background. It's likely it will
+        # outlast the current request, so run it in the sentinel logcontext.
+        with PreserveLoggingContext():
+            self._process_event_queue_loop()
+
+    @defer.inlineCallbacks
+    def _process_event_queue_loop(self):
         try:
             self._is_processing = True
             while True:
                 last_token = yield self.store.get_federation_out_pos("events")
                 next_token, events = yield self.store.get_all_new_events_stream(
-                    last_token, self._last_poked_id, limit=20,
+                    last_token, self._last_poked_id, limit=100,
                 )
 
                 logger.debug("Handling %s -> %s", last_token, next_token)
@@ -153,28 +177,35 @@ class TransactionQueue(object):
                 if not events and next_token >= self._last_poked_id:
                     break
 
-                for event in events:
+                @defer.inlineCallbacks
+                def handle_event(event):
                     # Only send events for this server.
                     send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
                     is_mine = self.is_mine_id(event.event_id)
                     if not is_mine and send_on_behalf_of is None:
-                        continue
-
-                    # Get the state from before the event.
-                    # We need to make sure that this is the state from before
-                    # the event and not from after it.
-                    # Otherwise if the last member on a server in a room is
-                    # banned then it won't receive the event because it won't
-                    # be in the room after the ban.
-                    users_in_room = yield self.state.get_current_user_in_room(
-                        event.room_id, latest_event_ids=[
-                            prev_id for prev_id, _ in event.prev_events
-                        ],
-                    )
+                        return
+
+                    try:
+                        # Get the state from before the event.
+                        # We need to make sure that this is the state from before
+                        # the event and not from after it.
+                        # Otherwise if the last member on a server in a room is
+                        # banned then it won't receive the event because it won't
+                        # be in the room after the ban.
+                        destinations = yield self.state.get_current_hosts_in_room(
+                            event.room_id, latest_event_ids=[
+                                prev_id for prev_id, _ in event.prev_events
+                            ],
+                        )
+                    except Exception:
+                        logger.exception(
+                            "Failed to calculate hosts in room for event: %s",
+                            event.event_id,
+                        )
+                        return
+
+                    destinations = set(destinations)
 
-                    destinations = set(
-                        get_domain_from_id(user_id) for user_id in users_in_room
-                    )
                     if send_on_behalf_of is not None:
                         # If we are sending the event on behalf of another server
                         # then it already has the event and there is no reason to
@@ -185,10 +216,44 @@ class TransactionQueue(object):
 
                     self._send_pdu(event, destinations)
 
+                @defer.inlineCallbacks
+                def handle_room_events(events):
+                    for event in events:
+                        yield handle_event(event)
+
+                events_by_room = {}
+                for event in events:
+                    events_by_room.setdefault(event.room_id, []).append(event)
+
+                yield logcontext.make_deferred_yieldable(defer.gatherResults(
+                    [
+                        logcontext.run_in_background(handle_room_events, evs)
+                        for evs in events_by_room.itervalues()
+                    ],
+                    consumeErrors=True
+                ))
+
                 yield self.store.update_federation_out_pos(
                     "events", next_token
                 )
 
+                if events:
+                    now = self.clock.time_msec()
+                    ts = yield self.store.get_received_ts(events[-1].event_id)
+
+                    synapse.metrics.event_processing_lag.set(
+                        now - ts, "federation_sender",
+                    )
+                    synapse.metrics.event_processing_last_ts.set(
+                        ts, "federation_sender",
+                    )
+
+                events_processed_counter.inc_by(len(events))
+
+                synapse.metrics.event_processing_positions.set(
+                    next_token, "federation_sender",
+                )
+
         finally:
             self._is_processing = False
 
@@ -217,21 +282,75 @@ class TransactionQueue(object):
                 (pdu, order)
             )
 
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
+            self._attempt_new_transaction(destination)
 
-    def send_presence(self, destination, states):
-        if not self.can_send_to(destination):
-            return
+    @logcontext.preserve_fn  # the caller should not yield on this
+    @defer.inlineCallbacks
+    def send_presence(self, states):
+        """Send the new presence states to the appropriate destinations.
+
+        This actually queues up the presence states ready for sending and
+        triggers a background task to process them and send out the transactions.
+
+        Args:
+            states (list(UserPresenceState))
+        """
 
-        self.pending_presence_by_dest.setdefault(destination, {}).update({
+        # First we queue up the new presence by user ID, so multiple presence
+        # updates in quick successtion are correctly handled
+        # We only want to send presence for our own users, so lets always just
+        # filter here just in case.
+        self.pending_presence.update({
             state.user_id: state for state in states
+            if self.is_mine_id(state.user_id)
         })
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        # We then handle the new pending presence in batches, first figuring
+        # out the destinations we need to send each state to and then poking it
+        # to attempt a new transaction. We linearize this so that we don't
+        # accidentally mess up the ordering and send multiple presence updates
+        # in the wrong order
+        if self._processing_pending_presence:
+            return
+
+        self._processing_pending_presence = True
+        try:
+            while True:
+                states_map = self.pending_presence
+                self.pending_presence = {}
+
+                if not states_map:
+                    break
+
+                yield self._process_presence_inner(states_map.values())
+        except Exception:
+            logger.exception("Error sending presence states to servers")
+        finally:
+            self._processing_pending_presence = False
+
+    @measure_func("txnqueue._process_presence")
+    @defer.inlineCallbacks
+    def _process_presence_inner(self, states):
+        """Given a list of states populate self.pending_presence_by_dest and
+        poke to send a new transaction to each destination
+
+        Args:
+            states (list(UserPresenceState))
+        """
+        hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
+
+        for destinations, states in hosts_and_states:
+            for destination in destinations:
+                if not self.can_send_to(destination):
+                    continue
+
+                self.pending_presence_by_dest.setdefault(
+                    destination, {}
+                ).update({
+                    state.user_id: state for state in states
+                })
+
+                self._attempt_new_transaction(destination)
 
     def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
@@ -253,9 +372,7 @@ class TransactionQueue(object):
         else:
             self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
@@ -268,9 +385,7 @@ class TransactionQueue(object):
             destination, []
         ).append(failure)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
@@ -279,15 +394,24 @@ class TransactionQueue(object):
         if not self.can_send_to(destination):
             return
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def get_current_token(self):
         return 0
 
-    @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
+        """Try to start a new transaction to this destination
+
+        If there is already a transaction in progress to this destination,
+        returns immediately. Otherwise kicks off the process of sending a
+        transaction in the background.
+
+        Args:
+            destination (str):
+
+        Returns:
+            None
+        """
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -300,12 +424,46 @@ class TransactionQueue(object):
             )
             return
 
+        logger.debug("TX [%s] Starting transaction loop", destination)
+
+        # Drop the logcontext before starting the transaction. It doesn't
+        # really make sense to log all the outbound transactions against
+        # whatever path led us to this point: that's pretty arbitrary really.
+        #
+        # (this also means we can fire off _perform_transaction without
+        # yielding)
+        with logcontext.PreserveLoggingContext():
+            self._transaction_transmission_loop(destination)
+
+    @defer.inlineCallbacks
+    def _transaction_transmission_loop(self, destination):
+        pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
 
+            # This will throw if we wouldn't retry. We do this here so we fail
+            # quickly, but we will later check this again in the http client,
+            # hence why we throw the result away.
+            yield get_retry_limiter(destination, self.clock, self.store)
+
+            # XXX: what's this for?
             yield run_on_reactor()
 
+            pending_pdus = []
             while True:
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
+
+                # BEGIN CRITICAL SECTION
+                #
+                # In order to avoid a race condition, we need to make sure that
+                # the following code (from popping the queues up to the point
+                # where we decide if we actually have any pending messages) is
+                # atomic - otherwise new PDUs or EDUs might arrive in the
+                # meantime, but not get sent because we hold the
+                # pending_transactions flag.
+
                 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                 pending_edus = self.pending_edus_by_dest.pop(destination, [])
                 pending_presence = self.pending_presence_by_dest.pop(destination, {})
@@ -315,17 +473,6 @@ class TransactionQueue(object):
                     self.pending_edus_keyed_by_dest.pop(destination, {}).values()
                 )
 
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                    backoff_on_404=True,  # If we get a 404 the other side has gone
-                )
-
-                device_message_edus, device_stream_id, dev_list_id = (
-                    yield self._get_new_device_messages(destination)
-                )
-
                 pending_edus.extend(device_message_edus)
                 if pending_presence:
                     pending_edus.append(
@@ -355,11 +502,13 @@ class TransactionQueue(object):
                     )
                     return
 
+                # END CRITICAL SECTION
+
                 success = yield self._send_new_transaction(
                     destination, pending_pdus, pending_edus, pending_failures,
-                    limiter=limiter,
                 )
                 if success:
+                    sent_transactions_counter.inc()
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
                     if device_message_edus:
@@ -375,12 +524,26 @@ class TransactionQueue(object):
                     self.last_device_list_stream_id_by_dest[destination] = dev_list_id
                 else:
                     break
-        except NotRetryingDestination:
+        except NotRetryingDestination as e:
             logger.debug(
-                "TX [%s] not ready for retry yet - "
+                "TX [%s] not ready for retry yet (next retry at %s) - "
                 "dropping transaction for now",
                 destination,
+                datetime.datetime.fromtimestamp(
+                    (e.retry_last_ts + e.retry_interval) / 1000.0
+                ),
             )
+        except FederationDeniedError as e:
+            logger.info(e)
+        except Exception as e:
+            logger.warn(
+                "TX [%s] Failed to send transaction: %s",
+                destination,
+                e,
+            )
+            for p, _ in pending_pdus:
+                logger.info("Failed to send event %s to %s", p.event_id,
+                            destination)
         finally:
             # We want to be *very* sure we delete this after we stop processing
             self.pending_transactions.pop(destination, None)
@@ -420,7 +583,7 @@ class TransactionQueue(object):
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, limiter):
+                              pending_failures):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -430,132 +593,104 @@ class TransactionQueue(object):
 
         success = True
 
-        try:
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            txn_id = str(self._next_txn_id)
+        txn_id = str(self._next_txn_id)
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pdus),
-                len(edus),
-                len(failures)
-            )
+        logger.debug(
+            "TX [%s] {%s} Attempting new transaction"
+            " (pdus: %d, edus: %d, failures: %d)",
+            destination, txn_id,
+            len(pdus),
+            len(edus),
+            len(failures)
+        )
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+        logger.debug("TX [%s] Persisting transaction...", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+        transaction = Transaction.create_new(
+            origin_server_ts=int(self.clock.time_msec()),
+            transaction_id=txn_id,
+            origin=self.server_name,
+            destination=destination,
+            pdus=pdus,
+            edus=edus,
+            pdu_failures=failures,
+        )
 
-            self._next_txn_id += 1
+        self._next_txn_id += 1
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+        yield self.transaction_actions.prepare_to_send(transaction)
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pdus),
-                len(edus),
-                len(failures),
-            )
+        logger.debug("TX [%s] Persisted transaction", destination)
+        logger.info(
+            "TX [%s] {%s} Sending transaction [%s],"
+            " (PDUs: %d, EDUs: %d, failures: %d)",
+            destination, txn_id,
+            transaction.transaction_id,
+            len(pdus),
+            len(edus),
+            len(failures),
+        )
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
-
-                    if e.code in (401, 404, 429) or 500 <= e.code:
-                        logger.info(
-                            "TX [%s] {%s} got %d response",
-                            destination, txn_id, code
+        # Actually send the transaction
+
+        # FIXME (erikj): This is a bit of a hack to make the Pdu age
+        # keys work
+        def json_data_cb():
+            data = transaction.get_dict()
+            now = int(self.clock.time_msec())
+            if "pdus" in data:
+                for p in data["pdus"]:
+                    if "age_ts" in p:
+                        unsigned = p.setdefault("unsigned", {})
+                        unsigned["age"] = now - int(p["age_ts"])
+                        del p["age_ts"]
+            return data
+
+        try:
+            response = yield self.transport_layer.send_transaction(
+                transaction, json_data_cb
+            )
+            code = 200
+
+            if response:
+                for e_id, r in response.get("pdus", {}).items():
+                    if "error" in r:
+                        logger.warn(
+                            "Transaction returned error for %s: %s",
+                            e_id, r,
                         )
-                        raise e
+        except HttpResponseException as e:
+            code = e.code
+            response = e.response
 
+            if e.code in (401, 404, 429) or 500 <= e.code:
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code
                 )
+                raise e
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
-
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+        logger.info(
+            "TX [%s] {%s} got %d response",
+            destination, txn_id, code
+        )
 
-            logger.debug("TX [%s] Marked as delivered", destination)
+        logger.debug("TX [%s] Sent transaction", destination)
+        logger.debug("TX [%s] Marking as delivered...", destination)
 
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-                success = False
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
+        yield self.transaction_actions.delivered(
+            transaction, code, response
+        )
 
-            success = False
+        logger.debug("TX [%s] Marked as delivered", destination)
 
+        if code != 200:
             for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
+                logger.info(
+                    "Failed to send event %s to %s", p.event_id, destination
+                )
             success = False
 
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-
         defer.returnValue(success)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index f49e8a2cc4..6db8efa6dd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.util.logutils import log_function
 
 import logging
+import urllib
 
 
 logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state dest=%s, room=%s",
                      destination, room_id)
 
-        path = PREFIX + "/state/%s/" % room_id
+        path = _create_path(PREFIX, "/state/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -71,7 +73,7 @@ class TransportLayerClient(object):
         logger.debug("get_room_state_ids dest=%s, room=%s",
                      destination, room_id)
 
-        path = PREFIX + "/state_ids/%s/" % room_id
+        path = _create_path(PREFIX, "/state_ids/%s/", room_id)
         return self.client.get_json(
             destination, path=path, args={"event_id": event_id},
         )
@@ -93,7 +95,7 @@ class TransportLayerClient(object):
         logger.debug("get_pdu dest=%s, event_id=%s",
                      destination, event_id)
 
-        path = PREFIX + "/event/%s/" % (event_id, )
+        path = _create_path(PREFIX, "/event/%s/", event_id)
         return self.client.get_json(destination, path=path, timeout=timeout)
 
     @log_function
@@ -119,7 +121,7 @@ class TransportLayerClient(object):
             # TODO: raise?
             return
 
-        path = PREFIX + "/backfill/%s/" % (room_id,)
+        path = _create_path(PREFIX, "/backfill/%s/", room_id)
 
         args = {
             "v": event_tuples,
@@ -157,12 +159,15 @@ class TransportLayerClient(object):
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
+        path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
         response = yield self.client.put_json(
             transaction.destination,
-            path=PREFIX + "/send/%s/" % transaction.transaction_id,
+            path=path,
             data=json_data,
             json_data_callback=json_data_callback,
             long_retries=True,
+            backoff_on_404=True,  # If we get a 404 the other side has gone
         )
 
         logger.debug(
@@ -174,8 +179,9 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
-    def make_query(self, destination, query_type, args, retry_on_dns_fail):
-        path = PREFIX + "/query/%s" % query_type
+    def make_query(self, destination, query_type, args, retry_on_dns_fail,
+                   ignore_backoff=False):
+        path = _create_path(PREFIX, "/query/%s", query_type)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -183,6 +189,7 @@ class TransportLayerClient(object):
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=10000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -190,19 +197,54 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def make_membership_event(self, destination, room_id, user_id, membership):
+        """Asks a remote server to build and sign us a membership event
+
+        Note that this does not append any events to any graphs.
+
+        Args:
+            destination (str): address of remote homeserver
+            room_id (str): room to join/leave
+            user_id (str): user to be joined/left
+            membership (str): one of join/leave
+
+        Returns:
+            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            will be the decoded JSON body (ie, the new event).
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if the remote destination
+            is not in our federation whitelist
+        """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
             raise RuntimeError(
                 "make_membership_event called with membership='%s', must be one of %s" %
                 (membership, ",".join(valid_memberships))
             )
-        path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+        path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
+
+        ignore_backoff = False
+        retry_on_dns_fail = False
+
+        if membership == Membership.LEAVE:
+            # we particularly want to do our best to send leave events. The
+            # problem is that if it fails, we won't retry it later, so if the
+            # remote server was just having a momentary blip, the room will be
+            # out of sync.
+            ignore_backoff = True
+            retry_on_dns_fail = True
 
         content = yield self.client.get_json(
             destination=destination,
             path=path,
-            retry_on_dns_fail=False,
+            retry_on_dns_fail=retry_on_dns_fail,
             timeout=20000,
+            ignore_backoff=ignore_backoff,
         )
 
         defer.returnValue(content)
@@ -210,7 +252,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_join(self, destination, room_id, event_id, content):
-        path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -223,12 +265,18 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_leave(self, destination, room_id, event_id, content):
-        path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
+
+            # we want to do our best to send this through. The problem is
+            # that if it fails, we won't retry it later, so if the remote
+            # server was just having a momentary blip, the room will be out of
+            # sync.
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -236,12 +284,13 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_invite(self, destination, room_id, event_id, content):
-        path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
 
         response = yield self.client.put_json(
             destination=destination,
             path=path,
             data=content,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -269,6 +318,7 @@ class TransportLayerClient(object):
             destination=remote_server,
             path=path,
             args=args,
+            ignore_backoff=True,
         )
 
         defer.returnValue(response)
@@ -276,7 +326,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(self, destination, room_id, event_dict):
-        path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+        path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
 
         response = yield self.client.put_json(
             destination=destination,
@@ -289,7 +339,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
-        path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -301,7 +351,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def send_query_auth(self, destination, room_id, event_id, content):
-        path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+        path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -363,7 +413,7 @@ class TransportLayerClient(object):
         Returns:
             A dict containg the device keys.
         """
-        path = PREFIX + "/user/devices/" + user_id
+        path = _create_path(PREFIX, "/user/devices/%s", user_id)
 
         content = yield self.client.get_json(
             destination=destination,
@@ -413,7 +463,7 @@ class TransportLayerClient(object):
     @log_function
     def get_missing_events(self, destination, room_id, earliest_events,
                            latest_events, limit, min_depth, timeout):
-        path = PREFIX + "/get_missing_events/%s" % (room_id,)
+        path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
 
         content = yield self.client.post_json(
             destination=destination,
@@ -428,3 +478,475 @@ class TransportLayerClient(object):
         )
 
         defer.returnValue(content)
+
+    @log_function
+    def get_group_profile(self, destination, group_id, requester_user_id):
+        """Get a group profile
+        """
+        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_profile(self, destination, group_id, requester_user_id, content):
+        """Update a remote group profile
+
+        Args:
+            destination (str)
+            group_id (str)
+            requester_user_id (str)
+            content (dict): The new profile of the group
+        """
+        path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_summary(self, destination, group_id, requester_user_id):
+        """Get a group summary
+        """
+        path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_rooms_in_group(self, destination, group_id, requester_user_id):
+        """Get all rooms in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
+                          content):
+        """Add a room to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
+                             config_key, content):
+        """Update room in group
+        """
+        path = _create_path(
+            PREFIX, "/groups/%s/room/%s/config/%s",
+            group_id, room_id, config_key,
+        )
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+        """Remove a room from a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_users_in_group(self, destination, group_id, requester_user_id):
+        """Get users in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+        """Get users that have been invited to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def accept_group_invite(self, destination, group_id, user_id, content):
+        """Accept a group invite
+        """
+        path = _create_path(
+            PREFIX, "/groups/%s/users/%s/accept_invite",
+            group_id, user_id,
+        )
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def join_group(self, destination, group_id, user_id, content):
+        """Attempts to join a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+        """Invite a user to a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def invite_to_group_notification(self, destination, group_id, user_id, content):
+        """Sent by group server to inform a user's server that they have been
+        invited.
+        """
+
+        path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def remove_user_from_group(self, destination, group_id, requester_user_id,
+                               user_id, content):
+        """Remove a user fron a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def remove_user_from_group_notification(self, destination, group_id, user_id,
+                                            content):
+        """Sent by group server to inform a user's server that they have been
+        kicked from the group.
+        """
+
+        path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def renew_group_attestation(self, destination, group_id, user_id, content):
+        """Sent by either a group server or a user's server to periodically update
+        the attestations
+        """
+
+        path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_summary_room(self, destination, group_id, user_id, room_id,
+                                  category_id, content):
+        """Update a room entry in a group summary
+        """
+        if category_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
+                group_id, category_id, room_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_summary_room(self, destination, group_id, user_id, room_id,
+                                  category_id):
+        """Delete a room entry in a group summary
+        """
+        if category_id:
+            path = _create_path(
+                PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
+                group_id, category_id, room_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_categories(self, destination, group_id, requester_user_id):
+        """Get all categories in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_category(self, destination, group_id, requester_user_id, category_id):
+        """Get category info in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_category(self, destination, group_id, requester_user_id, category_id,
+                              content):
+        """Update a category in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_category(self, destination, group_id, requester_user_id,
+                              category_id):
+        """Delete a category in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_roles(self, destination, group_id, requester_user_id):
+        """Get all roles in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def get_group_role(self, destination, group_id, requester_user_id, role_id):
+        """Get a roles info
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_role(self, destination, group_id, requester_user_id, role_id,
+                          content):
+        """Update a role in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+        """Delete a role in a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def update_group_summary_user(self, destination, group_id, requester_user_id,
+                                  user_id, role_id, content):
+        """Update a users entry in a group
+        """
+        if role_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+                group_id, role_id, user_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def set_group_join_policy(self, destination, group_id, requester_user_id,
+                              content):
+        """Sets the join policy for a group
+        """
+        path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+        return self.client.put_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            data=content,
+            ignore_backoff=True,
+        )
+
+    @log_function
+    def delete_group_summary_user(self, destination, group_id, requester_user_id,
+                                  user_id, role_id):
+        """Delete a users entry in a group
+        """
+        if role_id:
+            path = _create_path(
+                PREFIX, "/groups/%s/summary/roles/%s/users/%s",
+                group_id, role_id, user_id,
+            )
+        else:
+            path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
+
+        return self.client.delete_json(
+            destination=destination,
+            path=path,
+            args={"requester_user_id": requester_user_id},
+            ignore_backoff=True,
+        )
+
+    def bulk_get_publicised_groups(self, destination, user_ids):
+        """Get the groups a list of users are publicising
+        """
+
+        path = PREFIX + "/get_groups_publicised"
+
+        content = {"user_ids": user_ids}
+
+        return self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+        )
+
+
+def _create_path(prefix, path, *args):
+    """Creates a path from the prefix, path template and args. Ensures that
+    all args are url encoded.
+
+    Example:
+
+        _create_path(PREFIX, "/event/%s/", event_id)
+
+    Args:
+        prefix (str)
+        path (str): String template for the path
+        args: ([str]): Args to insert into path. Each arg will be url encoded
+
+    Returns:
+        str
+    """
+    return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index c840da834c..19d09f5422 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +17,7 @@
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -24,7 +25,8 @@ from synapse.http.servlet import (
 )
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
-from synapse.types import ThirdPartyInstanceID
+from synapse.util.logcontext import run_in_background
+from synapse.types import ThirdPartyInstanceID, get_domain_from_id
 
 import functools
 import logging
@@ -79,6 +81,8 @@ class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
@@ -110,7 +114,7 @@ class Authenticator(object):
                 key = strip_quotes(param_dict["key"])
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
-            except:
+            except Exception:
                 raise AuthenticationError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )
@@ -128,6 +132,12 @@ class Authenticator(object):
                 json_request["origin"] = origin
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
+        if (
+            self.federation_domain_whitelist is not None and
+            origin not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(origin)
+
         if not json_request["signatures"]:
             raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -138,18 +148,30 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
 
+        # If we get a valid signed request from the other side, its probably
+        # alive
+        retry_timings = yield self.store.get_destination_retry_timings(origin)
+        if retry_timings and retry_timings["retry_last_ts"]:
+            run_in_background(self._reset_retry_timings, origin)
+
         defer.returnValue(origin)
 
+    @defer.inlineCallbacks
+    def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            yield self.store.set_destination_retry_timings(origin, 0, 0)
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+
 
 class BaseFederationServlet(object):
     REQUIRE_AUTH = True
 
-    def __init__(self, handler, authenticator, ratelimiter, server_name,
-                 room_list_handler):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
-        self.room_list_handler = room_list_handler
 
     def _wrap(self, func):
         authenticator = self.authenticator
@@ -170,7 +192,7 @@ class BaseFederationServlet(object):
                 if self.REQUIRE_AUTH:
                     logger.exception("authenticate_request failed")
                     raise
-            except:
+            except Exception:
                 logger.exception("authenticate_request failed")
                 raise
 
@@ -263,7 +285,7 @@ class FederationSendServlet(BaseFederationServlet):
             code, response = yield self.handler.on_incoming_transaction(
                 transaction_data
             )
-        except:
+        except Exception:
             logger.exception("on_incoming_transaction failed")
             raise
 
@@ -581,7 +603,7 @@ class PublicRoomList(BaseFederationServlet):
         else:
             network_tuple = ThirdPartyInstanceID(None, None)
 
-        data = yield self.room_list_handler.get_local_public_room_list(
+        data = yield self.handler.get_local_public_room_list(
             limit, since_token,
             network_tuple=network_tuple
         )
@@ -602,7 +624,550 @@ class FederationVersionServlet(BaseFederationServlet):
         }))
 
 
-SERVLET_CLASSES = (
+class FederationGroupsProfileServlet(BaseFederationServlet):
+    """Get/set the basic profile of a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/profile$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_group_profile(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.update_group_profile(
+            group_id, requester_user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryServlet(BaseFederationServlet):
+    PATH = "/groups/(?P<group_id>[^/]*)/summary$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_group_summary(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRoomsServlet(BaseFederationServlet):
+    """Get the rooms in a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_rooms_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsServlet(BaseFederationServlet):
+    """Add/remove room from group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.add_room_to_group(
+            group_id, requester_user_id, room_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.remove_room_from_group(
+            group_id, requester_user_id, room_id,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
+    """Update room config in group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, room_id, config_key):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        result = yield self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class FederationGroupsUsersServlet(BaseFederationServlet):
+    """Get the users in a group on behalf of a user
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_users_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
+    """Get the users that have been invited to a group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.get_invited_users_in_group(
+            group_id, requester_user_id
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsInviteServlet(BaseFederationServlet):
+    """Ask a group server to invite someone to the group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.invite_to_group(
+            group_id, user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
+    """Accept an invitation from the group server
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.accept_invite(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsJoinServlet(BaseFederationServlet):
+    """Attempt to join a group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.join_group(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveUserServlet(BaseFederationServlet):
+    """Leave or kick a user from the group
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.remove_user_from_group(
+            group_id, user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsLocalInviteServlet(BaseFederationServlet):
+    """A group server has invited a local user
+    """
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "group_id doesn't match origin")
+
+        new_content = yield self.handler.on_invite(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
+    """A group server has removed a local user
+    """
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = yield self.handler.user_removed_from_group(
+            group_id, user_id, content,
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+    """A group or user's server renews their attestation
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, user_id):
+        # We don't need to check auth here as we check the attestation signatures
+
+        new_content = yield self.handler.on_renew_attestation(
+            group_id, user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
+    """Add/remove a room from the group summary, with optional category.
+
+    Matches both:
+        - /groups/:group/summary/rooms/:room_id
+        - /groups/:group/summary/categories/:category/rooms/:room_id
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/categories/(?P<category_id>[^/]+))?"
+        "/rooms/(?P<room_id>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, category_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.update_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoriesServlet(BaseFederationServlet):
+    """Get all categories for a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/categories/$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_categories(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsCategoryServlet(BaseFederationServlet):
+    """Add/remove/get a category in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_category(
+            group_id, requester_user_id, category_id
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.upsert_group_category(
+            group_id, requester_user_id, category_id, content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, category_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_category(
+            group_id, requester_user_id, category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsRolesServlet(BaseFederationServlet):
+    """Get roles in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/roles/$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_roles(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsRoleServlet(BaseFederationServlet):
+    """Add/remove/get a role in a group
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_GET(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = yield self.handler.get_group_role(
+            group_id, requester_user_id, role_id
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.update_group_role(
+            group_id, requester_user_id, role_id, content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, role_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_role(
+            group_id, requester_user_id, role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
+    """Add/remove a user from the group summary, with optional role.
+
+    Matches both:
+        - /groups/:group/summary/users/:user_id
+        - /groups/:group/summary/roles/:role/users/:user_id
+    """
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/roles/(?P<role_id>[^/]+))?"
+        "/users/(?P<user_id>[^/]*)$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query, group_id, role_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.update_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty string")
+
+        resp = yield self.handler.delete_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
+    """Get roles in a group
+    """
+    PATH = (
+        "/get_groups_publicised$"
+    )
+
+    @defer.inlineCallbacks
+    def on_POST(self, origin, content, query):
+        resp = yield self.handler.bulk_get_publicised_groups(
+            content["user_ids"], proxy=False,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+    """Sets whether a group is joinable without an invite or knock
+    """
+    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+    @defer.inlineCallbacks
+    def on_PUT(self, origin, content, query, group_id):
+        requester_user_id = parse_string_from_args(query, "requester_user_id")
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = yield self.handler.set_group_join_policy(
+            group_id, requester_user_id, content
+        )
+
+        defer.returnValue((200, new_content))
+
+
+FEDERATION_SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
@@ -625,17 +1190,85 @@ SERVLET_CLASSES = (
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     OpenIdUserInfo,
-    PublicRoomList,
     FederationVersionServlet,
 )
 
 
+ROOM_LIST_CLASSES = (
+    PublicRoomList,
+)
+
+GROUP_SERVER_SERVLET_CLASSES = (
+    FederationGroupsProfileServlet,
+    FederationGroupsSummaryServlet,
+    FederationGroupsRoomsServlet,
+    FederationGroupsUsersServlet,
+    FederationGroupsInvitedUsersServlet,
+    FederationGroupsInviteServlet,
+    FederationGroupsAcceptInviteServlet,
+    FederationGroupsJoinServlet,
+    FederationGroupsRemoveUserServlet,
+    FederationGroupsSummaryRoomsServlet,
+    FederationGroupsCategoriesServlet,
+    FederationGroupsCategoryServlet,
+    FederationGroupsRolesServlet,
+    FederationGroupsRoleServlet,
+    FederationGroupsSummaryUsersServlet,
+    FederationGroupsAddRoomsServlet,
+    FederationGroupsAddRoomsConfigServlet,
+    FederationGroupsSettingJoinPolicyServlet,
+)
+
+
+GROUP_LOCAL_SERVLET_CLASSES = (
+    FederationGroupsLocalInviteServlet,
+    FederationGroupsRemoveLocalUserServlet,
+    FederationGroupsBulkPublicisedServlet,
+)
+
+
+GROUP_ATTESTATION_SERVLET_CLASSES = (
+    FederationGroupsRenewAttestaionServlet,
+)
+
+
 def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in SERVLET_CLASSES:
+    for servletclass in FEDERATION_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_federation_server(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in ROOM_LIST_CLASSES:
+        servletclass(
+            handler=hs.get_room_list_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_groups_server_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+        servletclass(
+            handler=hs.get_groups_local_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
+        ).register(resource)
+
+    for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
         servletclass(
-            handler=hs.get_replication_layer(),
+            handler=hs.get_groups_attestation_renewer(),
             authenticator=authenticator,
             ratelimiter=ratelimiter,
             server_name=hs.hostname,
-            room_list_handler=hs.get_room_list_handler(),
         ).register(resource)
diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/groups/__init__.py
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
new file mode 100644
index 0000000000..6f11fa374b
--- /dev/null
+++ b/synapse/groups/attestations.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Attestations ensure that users and groups can't lie about their memberships.
+
+When a user joins a group the HS and GS swap attestations, which allow them
+both to independently prove to third parties their membership.These
+attestations have a validity period so need to be periodically renewed.
+
+If a user leaves (or gets kicked out of) a group, either side can still use
+their attestation to "prove" their membership, until the attestation expires.
+Therefore attestations shouldn't be relied on to prove membership in important
+cases, but can for less important situtations, e.g. showing a users membership
+of groups on their profile, showing flairs, etc.abs
+
+An attestsation is a signed blob of json that looks like:
+
+    {
+        "user_id": "@foo:a.example.com",
+        "group_id": "+bar:b.example.com",
+        "valid_until_ms": 1507994728530,
+        "signatures":{"matrix.org":{"ed25519:auto":"..."}}
+    }
+"""
+
+import logging
+import random
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+from synapse.util.logcontext import run_in_background
+
+from signedjson.sign import sign_json
+
+
+logger = logging.getLogger(__name__)
+
+
+# Default validity duration for new attestations we create
+DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
+
+# We add some jitter to the validity duration of attestations so that if we
+# add lots of users at once we don't need to renew them all at once.
+# The jitter is a multiplier picked randomly between the first and second number
+DEFAULT_ATTESTATION_JITTER = (0.9, 1.3)
+
+# Start trying to update our attestations when they come this close to expiring
+UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
+
+
+class GroupAttestationSigning(object):
+    """Creates and verifies group attestations.
+    """
+    def __init__(self, hs):
+        self.keyring = hs.get_keyring()
+        self.clock = hs.get_clock()
+        self.server_name = hs.hostname
+        self.signing_key = hs.config.signing_key[0]
+
+    @defer.inlineCallbacks
+    def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+        """Verifies that the given attestation matches the given parameters.
+
+        An optional server_name can be supplied to explicitly set which server's
+        signature is expected. Otherwise assumes that either the group_id or user_id
+        is local and uses the other's server as the one to check.
+        """
+
+        if not server_name:
+            if get_domain_from_id(group_id) == self.server_name:
+                server_name = get_domain_from_id(user_id)
+            elif get_domain_from_id(user_id) == self.server_name:
+                server_name = get_domain_from_id(group_id)
+            else:
+                raise Exception("Expected either group_id or user_id to be local")
+
+        if user_id != attestation["user_id"]:
+            raise SynapseError(400, "Attestation has incorrect user_id")
+
+        if group_id != attestation["group_id"]:
+            raise SynapseError(400, "Attestation has incorrect group_id")
+        valid_until_ms = attestation["valid_until_ms"]
+
+        # TODO: We also want to check that *new* attestations that people give
+        # us to store are valid for at least a little while.
+        if valid_until_ms < self.clock.time_msec():
+            raise SynapseError(400, "Attestation expired")
+
+        yield self.keyring.verify_json_for_server(server_name, attestation)
+
+    def create_attestation(self, group_id, user_id):
+        """Create an attestation for the group_id and user_id with default
+        validity length.
+        """
+        validity_period = DEFAULT_ATTESTATION_LENGTH_MS
+        validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+        valid_until_ms = int(self.clock.time_msec() + validity_period)
+
+        return sign_json({
+            "group_id": group_id,
+            "user_id": user_id,
+            "valid_until_ms": valid_until_ms,
+        }, self.server_name, self.signing_key)
+
+
+class GroupAttestionRenewer(object):
+    """Responsible for sending and receiving attestation updates.
+    """
+
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+        self.assestations = hs.get_groups_attestation_signing()
+        self.transport_client = hs.get_federation_transport_client()
+        self.is_mine_id = hs.is_mine_id
+        self.attestations = hs.get_groups_attestation_signing()
+
+        self._renew_attestations_loop = self.clock.looping_call(
+            self._renew_attestations, 30 * 60 * 1000,
+        )
+
+    @defer.inlineCallbacks
+    def on_renew_attestation(self, group_id, user_id, content):
+        """When a remote updates an attestation
+        """
+        attestation = content["attestation"]
+
+        if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
+            raise SynapseError(400, "Neither user not group are on this server")
+
+        yield self.attestations.verify_attestation(
+            attestation,
+            user_id=user_id,
+            group_id=group_id,
+        )
+
+        yield self.store.update_remote_attestion(group_id, user_id, attestation)
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def _renew_attestations(self):
+        """Called periodically to check if we need to update any of our attestations
+        """
+
+        now = self.clock.time_msec()
+
+        rows = yield self.store.get_attestations_need_renewals(
+            now + UPDATE_ATTESTATION_TIME_MS
+        )
+
+        @defer.inlineCallbacks
+        def _renew_attestation(group_id, user_id):
+            try:
+                if not self.is_mine_id(group_id):
+                    destination = get_domain_from_id(group_id)
+                elif not self.is_mine_id(user_id):
+                    destination = get_domain_from_id(user_id)
+                else:
+                    logger.warn(
+                        "Incorrectly trying to do attestations for user: %r in %r",
+                        user_id, group_id,
+                    )
+                    yield self.store.remove_attestation_renewal(group_id, user_id)
+                    return
+
+                attestation = self.attestations.create_attestation(group_id, user_id)
+
+                yield self.transport_client.renew_group_attestation(
+                    destination, group_id, user_id,
+                    content={"attestation": attestation},
+                )
+
+                yield self.store.update_attestation_renewal(
+                    group_id, user_id, attestation
+                )
+            except Exception:
+                logger.exception("Error renewing attestation of %r in %r",
+                                 user_id, group_id)
+
+        for row in rows:
+            group_id = row["group_id"]
+            user_id = row["user_id"]
+
+            run_in_background(_renew_attestation, group_id, user_id)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
new file mode 100644
index 0000000000..2d95b04e0c
--- /dev/null
+++ b/synapse/groups/groups_server.py
@@ -0,0 +1,950 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+# TODO: Allow users to "knock" or simpkly join depending on rules
+# TODO: Federation admin APIs
+# TODO: is_priveged flag to users and is_public to users and rooms
+# TODO: Audit log for admins (profile updates, membership changes, users who tried
+#       to join but were rejected, etc)
+# TODO: Flairs
+
+
+class GroupsServerHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.room_list_handler = hs.get_room_list_handler()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.keyring = hs.get_keyring()
+        self.is_mine_id = hs.is_mine_id
+        self.signing_key = hs.config.signing_key[0]
+        self.server_name = hs.hostname
+        self.attestations = hs.get_groups_attestation_signing()
+        self.transport_client = hs.get_federation_transport_client()
+        self.profile_handler = hs.get_profile_handler()
+
+        # Ensure attestations get renewed
+        hs.get_groups_attestation_renewer()
+
+    @defer.inlineCallbacks
+    def check_group_is_ours(self, group_id, requester_user_id,
+                            and_exists=False, and_is_admin=None):
+        """Check that the group is ours, and optionally if it exists.
+
+        If group does exist then return group.
+
+        Args:
+            group_id (str)
+            and_exists (bool): whether to also check if group exists
+            and_is_admin (str): whether to also check if given str is a user_id
+                that is an admin
+        """
+        if not self.is_mine_id(group_id):
+            raise SynapseError(400, "Group not on this server")
+
+        group = yield self.store.get_group(group_id)
+        if and_exists and not group:
+            raise SynapseError(404, "Unknown group")
+
+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+        if group and not is_user_in_group and not group["is_public"]:
+            raise SynapseError(404, "Unknown group")
+
+        if and_is_admin:
+            is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
+            if not is_admin:
+                raise SynapseError(403, "User is not admin in group")
+
+        defer.returnValue(group)
+
+    @defer.inlineCallbacks
+    def get_group_summary(self, group_id, requester_user_id):
+        """Get the summary for a group as seen by requester_user_id.
+
+        The group summary consists of the profile of the room, and a curated
+        list of users and rooms. These list *may* be organised by role/category.
+        The roles/categories are ordered, and so are the users/rooms within them.
+
+        A user/room may appear in multiple roles/categories.
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+        profile = yield self.get_group_profile(group_id, requester_user_id)
+
+        users, roles = yield self.store.get_users_for_summary_by_role(
+            group_id, include_private=is_user_in_group,
+        )
+
+        # TODO: Add profiles to users
+
+        rooms, categories = yield self.store.get_rooms_for_summary_by_category(
+            group_id, include_private=is_user_in_group,
+        )
+
+        for room_entry in rooms:
+            room_id = room_entry["room_id"]
+            joined_users = yield self.store.get_users_in_room(room_id)
+            entry = yield self.room_list_handler.generate_room_entry(
+                room_id, len(joined_users),
+                with_alias=False, allow_private=True,
+            )
+            entry = dict(entry)  # so we don't change whats cached
+            entry.pop("room_id", None)
+
+            room_entry["profile"] = entry
+
+        rooms.sort(key=lambda e: e.get("order", 0))
+
+        for entry in users:
+            user_id = entry["user_id"]
+
+            if not self.is_mine_id(requester_user_id):
+                attestation = yield self.store.get_remote_attestation(group_id, user_id)
+                if not attestation:
+                    continue
+
+                entry["attestation"] = attestation
+            else:
+                entry["attestation"] = self.attestations.create_attestation(
+                    group_id, user_id,
+                )
+
+            user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
+            entry.update(user_profile)
+
+        users.sort(key=lambda e: e.get("order", 0))
+
+        membership_info = yield self.store.get_users_membership_info_in_group(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue({
+            "profile": profile,
+            "users_section": {
+                "users": users,
+                "roles": roles,
+                "total_user_count_estimate": 0,  # TODO
+            },
+            "rooms_section": {
+                "rooms": rooms,
+                "categories": categories,
+                "total_room_count_estimate": 0,  # TODO
+            },
+            "user": membership_info,
+        })
+
+    @defer.inlineCallbacks
+    def update_group_summary_room(self, group_id, requester_user_id,
+                                  room_id, category_id, content):
+        """Add/update a room to the group summary
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )
+
+        RoomID.from_string(room_id)  # Ensure valid room id
+
+        order = content.get("order", None)
+
+        is_public = _parse_visibility_from_contents(content)
+
+        yield self.store.add_room_to_summary(
+            group_id=group_id,
+            room_id=room_id,
+            category_id=category_id,
+            order=order,
+            is_public=is_public,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def delete_group_summary_room(self, group_id, requester_user_id,
+                                  room_id, category_id):
+        """Remove a room from the summary
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )
+
+        yield self.store.remove_room_from_summary(
+            group_id=group_id,
+            room_id=room_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def set_group_join_policy(self, group_id, requester_user_id, content):
+        """Sets the group join policy.
+
+        Currently supported policies are:
+         - "invite": an invite must be received and accepted in order to join.
+         - "open": anyone can join.
+        """
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        join_policy = _parse_join_policy_from_contents(content)
+        if join_policy is None:
+            raise SynapseError(
+                400, "No value specified for 'm.join_policy'"
+            )
+
+        yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def get_group_categories(self, group_id, requester_user_id):
+        """Get all categories in a group (as seen by user)
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        categories = yield self.store.get_group_categories(
+            group_id=group_id,
+        )
+        defer.returnValue({"categories": categories})
+
+    @defer.inlineCallbacks
+    def get_group_category(self, group_id, requester_user_id, category_id):
+        """Get a specific category in a group (as seen by user)
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        res = yield self.store.get_group_category(
+            group_id=group_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def update_group_category(self, group_id, requester_user_id, category_id, content):
+        """Add/Update a group category
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )
+
+        is_public = _parse_visibility_from_contents(content)
+        profile = content.get("profile")
+
+        yield self.store.upsert_group_category(
+            group_id=group_id,
+            category_id=category_id,
+            is_public=is_public,
+            profile=profile,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def delete_group_category(self, group_id, requester_user_id, category_id):
+        """Delete a group category
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id
+        )
+
+        yield self.store.remove_group_category(
+            group_id=group_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def get_group_roles(self, group_id, requester_user_id):
+        """Get all roles in a group (as seen by user)
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        roles = yield self.store.get_group_roles(
+            group_id=group_id,
+        )
+        defer.returnValue({"roles": roles})
+
+    @defer.inlineCallbacks
+    def get_group_role(self, group_id, requester_user_id, role_id):
+        """Get a specific role in a group (as seen by user)
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        res = yield self.store.get_group_role(
+            group_id=group_id,
+            role_id=role_id,
+        )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def update_group_role(self, group_id, requester_user_id, role_id, content):
+        """Add/update a role in a group
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )
+
+        is_public = _parse_visibility_from_contents(content)
+
+        profile = content.get("profile")
+
+        yield self.store.upsert_group_role(
+            group_id=group_id,
+            role_id=role_id,
+            is_public=is_public,
+            profile=profile,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def delete_group_role(self, group_id, requester_user_id, role_id):
+        """Remove role from group
+        """
+        yield self.check_group_is_ours(
+            group_id,
+            requester_user_id,
+            and_exists=True,
+            and_is_admin=requester_user_id,
+        )
+
+        yield self.store.remove_group_role(
+            group_id=group_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
+                                  content):
+        """Add/update a users entry in the group summary
+        """
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+        )
+
+        order = content.get("order", None)
+
+        is_public = _parse_visibility_from_contents(content)
+
+        yield self.store.add_user_to_summary(
+            group_id=group_id,
+            user_id=user_id,
+            role_id=role_id,
+            order=order,
+            is_public=is_public,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
+        """Remove a user from the group summary
+        """
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+        )
+
+        yield self.store.remove_user_from_summary(
+            group_id=group_id,
+            user_id=user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def get_group_profile(self, group_id, requester_user_id):
+        """Get the group profile as seen by requester_user_id
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id)
+
+        group = yield self.store.get_group(group_id)
+
+        if group:
+            cols = [
+                "name", "short_description", "long_description",
+                "avatar_url", "is_public",
+            ]
+            group_description = {key: group[key] for key in cols}
+            group_description["is_openly_joinable"] = group["join_policy"] == "open"
+
+            defer.returnValue(group_description)
+        else:
+            raise SynapseError(404, "Unknown group")
+
+    @defer.inlineCallbacks
+    def update_group_profile(self, group_id, requester_user_id, content):
+        """Update the group profile
+        """
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+        )
+
+        profile = {}
+        for keyname in ("name", "avatar_url", "short_description",
+                        "long_description"):
+            if keyname in content:
+                value = content[keyname]
+                if not isinstance(value, basestring):
+                    raise SynapseError(400, "%r value is not a string" % (keyname,))
+                profile[keyname] = value
+
+        yield self.store.update_group_profile(group_id, profile)
+
+    @defer.inlineCallbacks
+    def get_users_in_group(self, group_id, requester_user_id):
+        """Get the users in group as seen by requester_user_id.
+
+        The ordering is arbitrary at the moment
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+        user_results = yield self.store.get_users_in_group(
+            group_id, include_private=is_user_in_group,
+        )
+
+        chunk = []
+        for user_result in user_results:
+            g_user_id = user_result["user_id"]
+            is_public = user_result["is_public"]
+            is_privileged = user_result["is_admin"]
+
+            entry = {"user_id": g_user_id}
+
+            profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+            entry.update(profile)
+
+            entry["is_public"] = bool(is_public)
+            entry["is_privileged"] = bool(is_privileged)
+
+            if not self.is_mine_id(g_user_id):
+                attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
+                if not attestation:
+                    continue
+
+                entry["attestation"] = attestation
+            else:
+                entry["attestation"] = self.attestations.create_attestation(
+                    group_id, g_user_id,
+                )
+
+            chunk.append(entry)
+
+        # TODO: If admin add lists of users whose attestations have timed out
+
+        defer.returnValue({
+            "chunk": chunk,
+            "total_user_count_estimate": len(user_results),
+        })
+
+    @defer.inlineCallbacks
+    def get_invited_users_in_group(self, group_id, requester_user_id):
+        """Get the users that have been invited to a group as seen by requester_user_id.
+
+        The ordering is arbitrary at the moment
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+        if not is_user_in_group:
+            raise SynapseError(403, "User not in group")
+
+        invited_users = yield self.store.get_invited_users_in_group(group_id)
+
+        user_profiles = []
+
+        for user_id in invited_users:
+            user_profile = {
+                "user_id": user_id
+            }
+            try:
+                profile = yield self.profile_handler.get_profile_from_cache(user_id)
+                user_profile.update(profile)
+            except Exception as e:
+                logger.warn("Error getting profile for %s: %s", user_id, e)
+            user_profiles.append(user_profile)
+
+        defer.returnValue({
+            "chunk": user_profiles,
+            "total_user_count_estimate": len(invited_users),
+        })
+
+    @defer.inlineCallbacks
+    def get_rooms_in_group(self, group_id, requester_user_id):
+        """Get the rooms in group as seen by requester_user_id
+
+        This returns rooms in order of decreasing number of joined users
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+
+        room_results = yield self.store.get_rooms_in_group(
+            group_id, include_private=is_user_in_group,
+        )
+
+        chunk = []
+        for room_result in room_results:
+            room_id = room_result["room_id"]
+
+            joined_users = yield self.store.get_users_in_room(room_id)
+            entry = yield self.room_list_handler.generate_room_entry(
+                room_id, len(joined_users),
+                with_alias=False, allow_private=True,
+            )
+
+            if not entry:
+                continue
+
+            entry["is_public"] = bool(room_result["is_public"])
+
+            chunk.append(entry)
+
+        chunk.sort(key=lambda e: -e["num_joined_members"])
+
+        defer.returnValue({
+            "chunk": chunk,
+            "total_room_count_estimate": len(room_results),
+        })
+
+    @defer.inlineCallbacks
+    def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+        """Add room to group
+        """
+        RoomID.from_string(room_id)  # Ensure valid room id
+
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        is_public = _parse_visibility_from_contents(content)
+
+        yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
+                             content):
+        """Update room in group
+        """
+        RoomID.from_string(room_id)  # Ensure valid room id
+
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        if config_key == "m.visibility":
+            is_public = _parse_visibility_dict(content)
+
+            yield self.store.update_room_in_group_visibility(
+                group_id, room_id,
+                is_public=is_public,
+            )
+        else:
+            raise SynapseError(400, "Uknown config option")
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def remove_room_from_group(self, group_id, requester_user_id, room_id):
+        """Remove room from group
+        """
+        yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        yield self.store.remove_room_from_group(group_id, room_id)
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def invite_to_group(self, group_id, user_id, requester_user_id, content):
+        """Invite user to group
+        """
+
+        group = yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+        )
+
+        # TODO: Check if user knocked
+        # TODO: Check if user is already invited
+
+        content = {
+            "profile": {
+                "name": group["name"],
+                "avatar_url": group["avatar_url"],
+            },
+            "inviter": requester_user_id,
+        }
+
+        if self.hs.is_mine_id(user_id):
+            groups_local = self.hs.get_groups_local_handler()
+            res = yield groups_local.on_invite(group_id, user_id, content)
+            local_attestation = None
+        else:
+            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            content.update({
+                "attestation": local_attestation,
+            })
+
+            res = yield self.transport_client.invite_to_group_notification(
+                get_domain_from_id(user_id), group_id, user_id, content
+            )
+
+            user_profile = res.get("user_profile", {})
+            yield self.store.add_remote_profile_cache(
+                user_id,
+                displayname=user_profile.get("displayname"),
+                avatar_url=user_profile.get("avatar_url"),
+            )
+
+        if res["state"] == "join":
+            if not self.hs.is_mine_id(user_id):
+                remote_attestation = res["attestation"]
+
+                yield self.attestations.verify_attestation(
+                    remote_attestation,
+                    user_id=user_id,
+                    group_id=group_id,
+                )
+            else:
+                remote_attestation = None
+
+            yield self.store.add_user_to_group(
+                group_id, user_id,
+                is_admin=False,
+                is_public=False,  # TODO
+                local_attestation=local_attestation,
+                remote_attestation=remote_attestation,
+            )
+        elif res["state"] == "invite":
+            yield self.store.add_group_invite(
+                group_id, user_id,
+            )
+            defer.returnValue({
+                "state": "invite"
+            })
+        elif res["state"] == "reject":
+            defer.returnValue({
+                "state": "reject"
+            })
+        else:
+            raise SynapseError(502, "Unknown state returned by HS")
+
+    @defer.inlineCallbacks
+    def _add_user(self, group_id, user_id, content):
+        """Add a user to a group based on a content dict.
+
+        See accept_invite, join_group.
+        """
+        if not self.hs.is_mine_id(user_id):
+            local_attestation = self.attestations.create_attestation(
+                group_id, user_id,
+            )
+
+            remote_attestation = content["attestation"]
+
+            yield self.attestations.verify_attestation(
+                remote_attestation,
+                user_id=user_id,
+                group_id=group_id,
+            )
+        else:
+            local_attestation = None
+            remote_attestation = None
+
+        is_public = _parse_visibility_from_contents(content)
+
+        yield self.store.add_user_to_group(
+            group_id, user_id,
+            is_admin=False,
+            is_public=is_public,
+            local_attestation=local_attestation,
+            remote_attestation=remote_attestation,
+        )
+
+        defer.returnValue(local_attestation)
+
+    @defer.inlineCallbacks
+    def accept_invite(self, group_id, requester_user_id, content):
+        """User tries to accept an invite to the group.
+
+        This is different from them asking to join, and so should error if no
+        invite exists (and they're not a member of the group)
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_invited = yield self.store.is_user_invited_to_local_group(
+            group_id, requester_user_id,
+        )
+        if not is_invited:
+            raise SynapseError(403, "User not invited to group")
+
+        local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+        defer.returnValue({
+            "state": "join",
+            "attestation": local_attestation,
+        })
+
+    @defer.inlineCallbacks
+    def join_group(self, group_id, requester_user_id, content):
+        """User tries to join the group.
+
+        This will error if the group requires an invite/knock to join
+        """
+
+        group_info = yield self.check_group_is_ours(
+            group_id, requester_user_id, and_exists=True
+        )
+        if group_info['join_policy'] != "open":
+            raise SynapseError(403, "Group is not publicly joinable")
+
+        local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+        defer.returnValue({
+            "state": "join",
+            "attestation": local_attestation,
+        })
+
+    @defer.inlineCallbacks
+    def knock(self, group_id, requester_user_id, content):
+        """A user requests becoming a member of the group
+        """
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        raise NotImplementedError()
+
+    @defer.inlineCallbacks
+    def accept_knock(self, group_id, requester_user_id, content):
+        """Accept a users knock to the room.
+
+        Errors if the user hasn't knocked, rather than inviting them.
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        raise NotImplementedError()
+
+    @defer.inlineCallbacks
+    def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+        """Remove a user from the group; either a user is leaving or an admin
+        kicked them.
+        """
+
+        yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+        is_kick = False
+        if requester_user_id != user_id:
+            is_admin = yield self.store.is_user_admin_in_group(
+                group_id, requester_user_id
+            )
+            if not is_admin:
+                raise SynapseError(403, "User is not admin in group")
+
+            is_kick = True
+
+        yield self.store.remove_user_from_group(
+            group_id, user_id,
+        )
+
+        if is_kick:
+            if self.hs.is_mine_id(user_id):
+                groups_local = self.hs.get_groups_local_handler()
+                yield groups_local.user_removed_from_group(group_id, user_id, {})
+            else:
+                yield self.transport_client.remove_user_from_group_notification(
+                    get_domain_from_id(user_id), group_id, user_id, {}
+                )
+
+        if not self.hs.is_mine_id(user_id):
+            yield self.store.maybe_delete_remote_profile_cache(user_id)
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def create_group(self, group_id, requester_user_id, content):
+        group = yield self.check_group_is_ours(group_id, requester_user_id)
+
+        logger.info("Attempting to create group with ID: %r", group_id)
+
+        # parsing the id into a GroupID validates it.
+        group_id_obj = GroupID.from_string(group_id)
+
+        if group:
+            raise SynapseError(400, "Group already exists")
+
+        is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
+        if not is_admin:
+            if not self.hs.config.enable_group_creation:
+                raise SynapseError(
+                    403, "Only a server admin can create groups on this server",
+                )
+            localpart = group_id_obj.localpart
+            if not localpart.startswith(self.hs.config.group_creation_prefix):
+                raise SynapseError(
+                    400,
+                    "Can only create groups with prefix %r on this server" % (
+                        self.hs.config.group_creation_prefix,
+                    ),
+                )
+
+        profile = content.get("profile", {})
+        name = profile.get("name")
+        avatar_url = profile.get("avatar_url")
+        short_description = profile.get("short_description")
+        long_description = profile.get("long_description")
+        user_profile = content.get("user_profile", {})
+
+        yield self.store.create_group(
+            group_id,
+            requester_user_id,
+            name=name,
+            avatar_url=avatar_url,
+            short_description=short_description,
+            long_description=long_description,
+        )
+
+        if not self.hs.is_mine_id(requester_user_id):
+            remote_attestation = content["attestation"]
+
+            yield self.attestations.verify_attestation(
+                remote_attestation,
+                user_id=requester_user_id,
+                group_id=group_id,
+            )
+
+            local_attestation = self.attestations.create_attestation(
+                group_id,
+                requester_user_id,
+            )
+        else:
+            local_attestation = None
+            remote_attestation = None
+
+        yield self.store.add_user_to_group(
+            group_id, requester_user_id,
+            is_admin=True,
+            is_public=True,  # TODO
+            local_attestation=local_attestation,
+            remote_attestation=remote_attestation,
+        )
+
+        if not self.hs.is_mine_id(requester_user_id):
+            yield self.store.add_remote_profile_cache(
+                requester_user_id,
+                displayname=user_profile.get("displayname"),
+                avatar_url=user_profile.get("avatar_url"),
+            )
+
+        defer.returnValue({
+            "group_id": group_id,
+        })
+
+
+def _parse_join_policy_from_contents(content):
+    """Given a content for a request, return the specified join policy or None
+    """
+
+    join_policy_dict = content.get("m.join_policy")
+    if join_policy_dict:
+        return _parse_join_policy_dict(join_policy_dict)
+    else:
+        return None
+
+
+def _parse_join_policy_dict(join_policy_dict):
+    """Given a dict for the "m.join_policy" config return the join policy specified
+    """
+    join_policy_type = join_policy_dict.get("type")
+    if not join_policy_type:
+        return "invite"
+
+    if join_policy_type not in ("invite", "open"):
+        raise SynapseError(
+            400, "Synapse only supports 'invite'/'open' join rule"
+        )
+    return join_policy_type
+
+
+def _parse_visibility_from_contents(content):
+    """Given a content for a request parse out whether the entity should be
+    public or not
+    """
+
+    visibility = content.get("m.visibility")
+    if visibility:
+        return _parse_visibility_dict(visibility)
+    else:
+        is_public = True
+
+    return is_public
+
+
+def _parse_visibility_dict(visibility):
+    """Given a dict for the "m.visibility" config return if the entity should
+    be public or not
+    """
+    vis_type = visibility.get("type")
+    if not vis_type:
+        return True
+
+    if vis_type not in ("public", "private"):
+        raise SynapseError(
+            400, "Synapse only supports 'public'/'private' visibility"
+        )
+    return vis_type == "public"
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 5ad408f549..8f8fd82eb0 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -17,10 +17,8 @@ from .register import RegistrationHandler
 from .room import (
     RoomCreationHandler, RoomContextHandler,
 )
-from .room_member import RoomMemberHandler
 from .message import MessageHandler
 from .federation import FederationHandler
-from .profile import ProfileHandler
 from .directory import DirectoryHandler
 from .admin import AdminHandler
 from .identity import IdentityHandler
@@ -50,9 +48,7 @@ class Handlers(object):
         self.registration_handler = RegistrationHandler(hs)
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
-        self.room_member_handler = RoomMemberHandler(hs)
         self.federation_handler = FederationHandler(hs)
-        self.profile_handler = ProfileHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
         self.admin_handler = AdminHandler(hs)
         self.identity_handler = IdentityHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index e83adc8339..e089e66fde 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,7 +53,20 @@ class BaseHandler(object):
 
         self.event_builder_factory = hs.get_event_builder_factory()
 
-    def ratelimit(self, requester):
+    @defer.inlineCallbacks
+    def ratelimit(self, requester, update=True):
+        """Ratelimits requests.
+
+        Args:
+            requester (Requester)
+            update (bool): Whether to record that a request is being processed.
+                Set to False when doing multiple checks for one request (e.g.
+                to check up front if we would reject the request), and set to
+                True for the last call for a given request.
+
+        Raises:
+            LimitExceededError if the request should be ratelimited
+        """
         time_now = self.clock.time()
         user_id = requester.user.to_string()
 
@@ -67,10 +80,25 @@ class BaseHandler(object):
         if requester.app_service and not requester.app_service.is_rate_limited():
             return
 
+        # Check if there is a per user override in the DB.
+        override = yield self.store.get_ratelimit_for_user(user_id)
+        if override:
+            # If overriden with a null Hz then ratelimiting has been entirely
+            # disabled for the user
+            if not override.messages_per_second:
+                return
+
+            messages_per_second = override.messages_per_second
+            burst_count = override.burst_count
+        else:
+            messages_per_second = self.hs.config.rc_messages_per_second
+            burst_count = self.hs.config.rc_message_burst_count
+
         allowed, time_allowed = self.ratelimiter.send_message(
             user_id, time_now,
-            msg_rate_hz=self.hs.config.rc_messages_per_second,
-            burst_count=self.hs.config.rc_message_burst_count,
+            msg_rate_hz=messages_per_second,
+            burst_count=burst_count,
+            update=update,
         )
         if not allowed:
             raise LimitExceededError(
@@ -130,7 +158,7 @@ class BaseHandler(object):
                 # homeserver.
                 requester = synapse.types.create_requester(
                     target_user, is_guest=True)
-                handler = self.hs.get_handlers().room_member_handler
+                handler = self.hs.get_room_member_handler()
                 yield handler.update_membership(
                     requester,
                     target_user,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 05af54d31b..b596f098fd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,14 +15,21 @@
 
 from twisted.internet import defer
 
+import synapse
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import (
+    make_deferred_yieldable, run_in_background,
+)
 
 import logging
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+events_processed_counter = metrics.register_counter("events_processed")
+
 
 def log_failure(failure):
     logger.error(
@@ -70,21 +77,25 @@ class ApplicationServicesHandler(object):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                upper_bound = self.current_max
                 limit = 100
                 while True:
                     upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
+                        self.current_max, limit
                     )
 
                     if not events:
                         break
 
+                    events_by_room = {}
                     for event in events:
+                        events_by_room.setdefault(event.room_id, []).append(event)
+
+                    @defer.inlineCallbacks
+                    def handle_event(event):
                         # Gather interested services
                         services = yield self._get_services_for_event(event)
                         if len(services) == 0:
-                            continue  # no services need notifying
+                            return  # no services need notifying
 
                         # Do we know this user exists? If not, poke the user
                         # query API for all services which match that user regex.
@@ -100,14 +111,35 @@ class ApplicationServicesHandler(object):
 
                         # Fork off pushes to these services
                         for service in services:
-                            preserve_fn(self.scheduler.submit_event_for_as)(
-                                service, event
-                            )
+                            self.scheduler.submit_event_for_as(service, event)
+
+                    @defer.inlineCallbacks
+                    def handle_room_events(events):
+                        for event in events:
+                            yield handle_event(event)
+
+                    yield make_deferred_yieldable(defer.gatherResults([
+                        run_in_background(handle_room_events, evs)
+                        for evs in events_by_room.itervalues()
+                    ], consumeErrors=True))
 
                     yield self.store.set_appservice_last_pos(upper_bound)
 
-                    if len(events) < limit:
-                        break
+                    now = self.clock.time_msec()
+                    ts = yield self.store.get_received_ts(events[-1].event_id)
+
+                    synapse.metrics.event_processing_positions.set(
+                        upper_bound, "appservice_sender",
+                    )
+
+                    events_processed_counter.inc_by(len(events))
+
+                    synapse.metrics.event_processing_lag.set(
+                        now - ts, "appservice_sender",
+                    )
+                    synapse.metrics.event_processing_last_ts.set(
+                        ts, "appservice_sender",
+                    )
             finally:
                 self.is_processing = False
 
@@ -163,8 +195,11 @@ class ApplicationServicesHandler(object):
     def query_3pe(self, kind, protocol, fields):
         services = yield self._get_services_for_3pn(protocol)
 
-        results = yield preserve_context_over_deferred(defer.DeferredList([
-            preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+        results = yield make_deferred_yieldable(defer.DeferredList([
+            run_in_background(
+                self.appservice_api.query_3pe,
+                service, kind, protocol, fields,
+            )
             for service in services
         ], consumeErrors=True))
 
@@ -225,11 +260,15 @@ class ApplicationServicesHandler(object):
             event based on the service regex.
         """
         services = self.store.get_app_services()
-        interested_list = [
-            s for s in services if (
-                yield s.is_interested(event, self.store)
-            )
-        ]
+
+        # we can't use a list comprehension here. Since python 3, list
+        # comprehensions use a generator internally. This means you can't yield
+        # inside of a list comprehension anymore.
+        interested_list = []
+        for s in services:
+            if (yield s.is_interested(event, self.store)):
+                interested_list.append(s)
+
         defer.returnValue(interested_list)
 
     def _get_services_for_user(self, user_id):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fffba34383..a5365c4fe4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from twisted.internet import defer
+from twisted.internet import defer, threads
 
 from ._base import BaseHandler
 from synapse.api.constants import LoginType
+from synapse.api.errors import (
+    AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
+    SynapseError,
+)
+from synapse.module_api import ModuleApi
 from synapse.types import UserID
-from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.logcontext import make_deferred_yieldable
 
 from twisted.web.client import PartialDownloadError
 
@@ -44,18 +50,23 @@ class AuthHandler(BaseHandler):
         """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
-            LoginType.PASSWORD: self._check_password_auth,
             LoginType.RECAPTCHA: self._check_recaptcha,
             LoginType.EMAIL_IDENTITY: self._check_email_identity,
+            LoginType.MSISDN: self._check_msisdn,
             LoginType.DUMMY: self._check_dummy_auth,
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
-        self.sessions = {}
 
-        account_handler = _AccountHandler(
-            hs, check_user_exists=self.check_user_exists
+        # This is not a cache per se, but a store of all current sessions that
+        # expire after N hours
+        self.sessions = ExpiringCache(
+            cache_name="register_sessions",
+            clock=hs.get_clock(),
+            expiry_ms=self.SESSION_EXPIRE_MS,
+            reset_expiry_on_get=True,
         )
 
+        account_handler = ModuleApi(hs, self)
         self.password_providers = [
             module(config=config, account_handler=account_handler)
             for module, config in hs.config.password_providers
@@ -64,39 +75,120 @@ class AuthHandler(BaseHandler):
         logger.info("Extra password_providers: %r", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
-        self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self._password_enabled = hs.config.password_enabled
+
+        # we keep this as a list despite the O(N^2) implication so that we can
+        # keep PASSWORD first and avoid confusing clients which pick the first
+        # type in the list. (NB that the spec doesn't require us to do so and
+        # clients which favour types that they don't understand over those that
+        # they do are technically broken)
+        login_types = []
+        if self._password_enabled:
+            login_types.append(LoginType.PASSWORD)
+        for provider in self.password_providers:
+            if hasattr(provider, "get_supported_login_types"):
+                for t in provider.get_supported_login_types().keys():
+                    if t not in login_types:
+                        login_types.append(t)
+        self._supported_login_types = login_types
+
+    @defer.inlineCallbacks
+    def validate_user_via_ui_auth(self, requester, request_body, clientip):
+        """
+        Checks that the user is who they claim to be, via a UI auth.
+
+        This is used for things like device deletion and password reset where
+        the user already has a valid access token, but we want to double-check
+        that it isn't stolen by re-authenticating them.
+
+        Args:
+            requester (Requester): The user, as given by the access token
+
+            request_body (dict): The body of the request sent by the client
+
+            clientip (str): The IP address of the client.
+
+        Returns:
+            defer.Deferred[dict]: the parameters for this request (which may
+                have been given only in a previous call).
+
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                any of the permitted login flows
+
+            AuthError if the client has completed a login flow, and it gives
+                a different user to `requester`
+        """
+
+        # build a list of supported flows
+        flows = [
+            [login_type] for login_type in self._supported_login_types
+        ]
+
+        result, params, _ = yield self.check_auth(
+            flows, request_body, clientip,
+        )
+
+        # find the completed login type
+        for login_type in self._supported_login_types:
+            if login_type not in result:
+                continue
+
+            user_id = result[login_type]
+            break
+        else:
+            # this can't happen
+            raise Exception(
+                "check_auth returned True but no successful login type",
+            )
+
+        # check that the UI auth matched the access token
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Invalid auth")
+
+        defer.returnValue(params)
 
     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
         """
         Takes a dictionary sent by the client in the login / registration
-        protocol and handles the login flow.
+        protocol and handles the User-Interactive Auth flow.
 
         As a side effect, this function fills in the 'creds' key on the user's
         session with a map, which maps each auth-type (str) to the relevant
         identity authenticated by that auth-type (mostly str, but for captcha, bool).
 
+        If no auth flows have been completed successfully, raises an
+        InteractiveAuthIncompleteError. To handle this, you can use
+        synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
+        decorator.
+
         Args:
             flows (list): A list of login flows. Each flow is an ordered list of
                           strings representing auth-types. At least one full
                           flow must be completed in order for auth to be successful.
+
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
+
             clientip (str): The IP address of the client.
+
         Returns:
-            A tuple of (authed, dict, dict, session_id) where authed is true if
-            the client has successfully completed an auth flow. If it is true
-            the first dict contains the authenticated credentials of each stage.
+            defer.Deferred[dict, dict, str]: a deferred tuple of
+                (creds, params, session_id).
+
+                'creds' contains the authenticated credentials of each stage.
 
-            If authed is false, the first dictionary is the server response to
-            the login request and should be passed back to the client.
+                'params' contains the parameters for this request (which may
+                have been given only in a previous call).
 
-            In either case, the second dict contains the parameters for this
-            request (which may have been given only in a previous call).
+                'session_id' is the ID of this session, either passed in by the
+                client or assigned by this call
 
-            session_id is the ID of this session, either passed in by the client
-            or assigned by the call to check_auth
+        Raises:
+            InteractiveAuthIncompleteError if the client has not yet completed
+                all the stages in any of the permitted flows.
         """
 
         authdict = None
@@ -124,11 +216,8 @@ class AuthHandler(BaseHandler):
             clientdict = session['clientdict']
 
         if not authdict:
-            defer.returnValue(
-                (
-                    False, self._auth_dict_for_flows(flows, session),
-                    clientdict, session['id']
-                )
+            raise InteractiveAuthIncompleteError(
+                self._auth_dict_for_flows(flows, session),
             )
 
         if 'creds' not in session:
@@ -139,14 +228,12 @@ class AuthHandler(BaseHandler):
         errordict = {}
         if 'type' in authdict:
             login_type = authdict['type']
-            if login_type not in self.checkers:
-                raise LoginError(400, "", Codes.UNRECOGNIZED)
             try:
-                result = yield self.checkers[login_type](authdict, clientip)
+                result = yield self._check_auth_dict(authdict, clientip)
                 if result:
                     creds[login_type] = result
                     self._save_session(session)
-            except LoginError, e:
+            except LoginError as e:
                 if login_type == LoginType.EMAIL_IDENTITY:
                     # riot used to have a bug where it would request a new
                     # validation token (thus sending a new email) each time it
@@ -155,7 +242,7 @@ class AuthHandler(BaseHandler):
                     #
                     # Grandfather in the old behaviour for now to avoid
                     # breaking old riot deployments.
-                    raise e
+                    raise
 
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
@@ -172,12 +259,14 @@ class AuthHandler(BaseHandler):
                     "Auth completed with creds: %r. Client dict has keys: %r",
                     creds, clientdict.keys()
                 )
-                defer.returnValue((True, creds, clientdict, session['id']))
+                defer.returnValue((creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
         ret.update(errordict)
-        defer.returnValue((False, ret, clientdict, session['id']))
+        raise InteractiveAuthIncompleteError(
+            ret,
+        )
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -249,16 +338,37 @@ class AuthHandler(BaseHandler):
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)
 
-    def _check_password_auth(self, authdict, _):
-        if "user" not in authdict or "password" not in authdict:
-            raise LoginError(400, "", Codes.MISSING_PARAM)
+    @defer.inlineCallbacks
+    def _check_auth_dict(self, authdict, clientip):
+        """Attempt to validate the auth dict provided by a client
 
-        user_id = authdict["user"]
-        password = authdict["password"]
-        if not user_id.startswith('@'):
-            user_id = UserID.create(user_id, self.hs.hostname).to_string()
+        Args:
+            authdict (object): auth dict provided by the client
+            clientip (str): IP address of the client
 
-        return self._check_password(user_id, password)
+        Returns:
+            Deferred: result of the stage verification.
+
+        Raises:
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
+        """
+        login_type = authdict['type']
+        checker = self.checkers.get(login_type)
+        if checker is not None:
+            res = yield checker(authdict, clientip)
+            defer.returnValue(res)
+
+        # build a v1-login-style dict out of the authdict and fall back to the
+        # v1 code
+        user_id = authdict.get("user")
+
+        if user_id is None:
+            raise SynapseError(400, "", Codes.MISSING_PARAM)
+
+        (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+        defer.returnValue(canonical_id)
 
     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -307,31 +417,47 @@ class AuthHandler(BaseHandler):
                 defer.returnValue(True)
         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
 
-    @defer.inlineCallbacks
     def _check_email_identity(self, authdict, _):
+        return self._check_threepid('email', authdict)
+
+    def _check_msisdn(self, authdict, _):
+        return self._check_threepid('msisdn', authdict)
+
+    @defer.inlineCallbacks
+    def _check_dummy_auth(self, authdict, _):
+        yield run_on_reactor()
+        defer.returnValue(True)
+
+    @defer.inlineCallbacks
+    def _check_threepid(self, medium, authdict):
         yield run_on_reactor()
 
         if 'threepid_creds' not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
         threepid_creds = authdict['threepid_creds']
+
         identity_handler = self.hs.get_handlers().identity_handler
 
-        logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
+        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
         threepid = yield identity_handler.threepid_from_creds(threepid_creds)
 
         if not threepid:
             raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
 
+        if threepid['medium'] != medium:
+            raise LoginError(
+                401,
+                "Expecting threepid of type '%s', got '%s'" % (
+                    medium, threepid['medium'],
+                ),
+                errcode=Codes.UNAUTHORIZED
+            )
+
         threepid['threepid_creds'] = authdict['threepid_creds']
 
         defer.returnValue(threepid)
 
-    @defer.inlineCallbacks
-    def _check_dummy_auth(self, authdict, _):
-        yield run_on_reactor()
-        defer.returnValue(True)
-
     def _get_params_recaptcha(self):
         return {"public_key": self.hs.config.recaptcha_public_key}
 
@@ -371,26 +497,8 @@ class AuthHandler(BaseHandler):
 
         return self.sessions[session_id]
 
-    def validate_password_login(self, user_id, password):
-        """
-        Authenticates the user with their username and password.
-
-        Used only by the v1 login API.
-
-        Args:
-            user_id (str): complete @user:id
-            password (str): Password
-        Returns:
-            defer.Deferred: (str) canonical user id
-        Raises:
-            StoreError if there was a problem accessing the database
-            LoginError if there was an authentication problem.
-        """
-        return self._check_password(user_id, password)
-
     @defer.inlineCallbacks
-    def get_access_token_for_user_id(self, user_id, device_id=None,
-                                     initial_display_name=None):
+    def get_access_token_for_user_id(self, user_id, device_id=None):
         """
         Creates a new access token for the user with the given user ID.
 
@@ -404,13 +512,10 @@ class AuthHandler(BaseHandler):
             device_id (str|None): the device ID to associate with the tokens.
                None to leave the tokens unassociated with a device (deprecated:
                we should always have a device ID)
-            initial_display_name (str): display name to associate with the
-               device if it needs re-registering
         Returns:
               The access token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
-            LoginError if there was an authentication problem.
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
@@ -420,9 +525,11 @@ class AuthHandler(BaseHandler):
         # really don't want is active access_tokens without a record of the
         # device, so we double-check it here.
         if device_id is not None:
-            yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
+            try:
+                yield self.store.get_device(user_id, device_id)
+            except StoreError:
+                yield self.store.delete_access_token(access_token)
+                raise StoreError(400, "Login raced against device deletion")
 
         defer.returnValue(access_token)
 
@@ -474,29 +581,115 @@ class AuthHandler(BaseHandler):
             )
         defer.returnValue(result)
 
+    def get_supported_login_types(self):
+        """Get a the login types supported for the /login API
+
+        By default this is just 'm.login.password' (unless password_enabled is
+        False in the config file), but password auth providers can provide
+        other login types.
+
+        Returns:
+            Iterable[str]: login types
+        """
+        return self._supported_login_types
+
     @defer.inlineCallbacks
-    def _check_password(self, user_id, password):
-        """Authenticate a user against the LDAP and local databases.
+    def validate_login(self, username, login_submission):
+        """Authenticates the user for the /login API
 
-        user_id is checked case insensitively against the local database, but
-        will throw if there are multiple inexact matches.
+        Also used by the user-interactive auth flow to validate
+        m.login.password auth types.
 
         Args:
-            user_id (str): complete @user:id
+            username (str): username supplied by the user
+            login_submission (dict): the whole of the login submission
+                (including 'type' and other relevant fields)
         Returns:
-            (str) the canonical_user_id
+            Deferred[str, func]: canonical user id, and optional callback
+                to be called once the access token and device id are issued
         Raises:
-            LoginError if login fails
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
         """
+
+        if username.startswith('@'):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(
+                username, self.hs.hostname
+            ).to_string()
+
+        login_type = login_submission.get("type")
+        known_login_type = False
+
+        # special case to check for "password" for the check_password interface
+        # for the auth providers
+        password = login_submission.get("password")
+        if login_type == LoginType.PASSWORD:
+            if not self._password_enabled:
+                raise SynapseError(400, "Password login has been disabled.")
+            if not password:
+                raise SynapseError(400, "Missing parameter: password")
+
         for provider in self.password_providers:
-            is_valid = yield provider.check_password(user_id, password)
-            if is_valid:
-                defer.returnValue(user_id)
+            if (hasattr(provider, "check_password")
+                    and login_type == LoginType.PASSWORD):
+                known_login_type = True
+                is_valid = yield provider.check_password(
+                    qualified_user_id, password,
+                )
+                if is_valid:
+                    defer.returnValue((qualified_user_id, None))
+
+            if (not hasattr(provider, "get_supported_login_types")
+                    or not hasattr(provider, "check_auth")):
+                # this password provider doesn't understand custom login types
+                continue
+
+            supported_login_types = provider.get_supported_login_types()
+            if login_type not in supported_login_types:
+                # this password provider doesn't understand this login type
+                continue
+
+            known_login_type = True
+            login_fields = supported_login_types[login_type]
+
+            missing_fields = []
+            login_dict = {}
+            for f in login_fields:
+                if f not in login_submission:
+                    missing_fields.append(f)
+                else:
+                    login_dict[f] = login_submission[f]
+            if missing_fields:
+                raise SynapseError(
+                    400, "Missing parameters for login type %s: %s" % (
+                        login_type,
+                        missing_fields,
+                    ),
+                )
 
-        canonical_user_id = yield self._check_local_password(user_id, password)
+            result = yield provider.check_auth(
+                username, login_type, login_dict,
+            )
+            if result:
+                if isinstance(result, str):
+                    result = (result, None)
+                defer.returnValue(result)
+
+        if login_type == LoginType.PASSWORD:
+            known_login_type = True
+
+            canonical_user_id = yield self._check_local_password(
+                qualified_user_id, password,
+            )
 
-        if canonical_user_id:
-            defer.returnValue(canonical_user_id)
+            if canonical_user_id:
+                defer.returnValue((canonical_user_id, None))
+
+        if not known_login_type:
+            raise SynapseError(400, "Unknown login type %s" % login_type)
 
         # unknown username or invalid password. We raise a 403 here, but note
         # that if we're doing user-interactive login, it turns all LoginErrors
@@ -522,7 +715,7 @@ class AuthHandler(BaseHandler):
         if not lookupres:
             defer.returnValue(None)
         (user_id, password_hash) = lookupres
-        result = self.validate_hash(password, password_hash)
+        result = yield self.validate_hash(password, password_hash)
         if not result:
             logger.warn("Failed password login for user %s", user_id)
             defer.returnValue(None)
@@ -546,22 +739,65 @@ class AuthHandler(BaseHandler):
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword, requester=None):
-        password_hash = self.hash(newpassword)
+    def delete_access_token(self, access_token):
+        """Invalidate a single access token
 
-        except_access_token_id = requester.access_token_id if requester else None
+        Args:
+            access_token (str): access token to be deleted
 
-        try:
-            yield self.store.user_set_password_hash(user_id, password_hash)
-        except StoreError as e:
-            if e.code == 404:
-                raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
-            raise e
-        yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_id
+        Returns:
+            Deferred
+        """
+        user_info = yield self.auth.get_user_by_access_token(access_token)
+        yield self.store.delete_access_token(access_token)
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                yield provider.on_logged_out(
+                    user_id=str(user_info["user"]),
+                    device_id=user_info["device_id"],
+                    access_token=access_token,
+                )
+
+        # delete pushers associated with this access token
+        if user_info["token_id"] is not None:
+            yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+                str(user_info["user"]), (user_info["token_id"], )
+            )
+
+    @defer.inlineCallbacks
+    def delete_access_tokens_for_user(self, user_id, except_token_id=None,
+                                      device_id=None):
+        """Invalidate access tokens belonging to a user
+
+        Args:
+            user_id (str):  ID of user the tokens belong to
+            except_token_id (str|None): access_token ID which should *not* be
+                deleted
+            device_id (str|None):  ID of device the tokens are associated with.
+                If None, tokens associated with any device (or no device) will
+                be deleted
+        Returns:
+            Deferred
+        """
+        tokens_and_devices = yield self.store.user_delete_access_tokens(
+            user_id, except_token_id=except_token_id, device_id=device_id,
         )
-        yield self.hs.get_pusherpool().remove_pushers_by_user(
-            user_id, except_access_token_id
+
+        # see if any of our auth providers want to know about this
+        for provider in self.password_providers:
+            if hasattr(provider, "on_logged_out"):
+                for token, token_id, device_id in tokens_and_devices:
+                    yield provider.on_logged_out(
+                        user_id=user_id,
+                        device_id=device_id,
+                        access_token=token,
+                    )
+
+        # delete pushers associated with the access tokens
+        yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+            user_id, (token_id for _, token_id, _ in tokens_and_devices),
         )
 
     @defer.inlineCallbacks
@@ -599,16 +835,6 @@ class AuthHandler(BaseHandler):
         logger.debug("Saving session %s", session)
         session["last_used"] = self.hs.get_clock().time_msec()
         self.sessions[session["id"]] = session
-        self._prune_sessions()
-
-    def _prune_sessions(self):
-        for sid, sess in self.sessions.items():
-            last_used = 0
-            if 'last_used' in sess:
-                last_used = sess['last_used']
-            now = self.hs.get_clock().time_msec()
-            if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
-                del self.sessions[sid]
 
     def hash(self, password):
         """Computes a secure hash of password.
@@ -617,10 +843,13 @@ class AuthHandler(BaseHandler):
             password (str): Password to hash.
 
         Returns:
-            Hashed password (str).
+            Deferred(str): Hashed password.
         """
-        return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
-                             bcrypt.gensalt(self.bcrypt_rounds))
+        def _do_hash():
+            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+                                 bcrypt.gensalt(self.bcrypt_rounds))
+
+        return make_deferred_yieldable(threads.deferToThread(_do_hash))
 
     def validate_hash(self, password, stored_hash):
         """Validates that self.hash(password) == stored_hash.
@@ -630,13 +859,19 @@ class AuthHandler(BaseHandler):
             stored_hash (str): Expected hash value.
 
         Returns:
-            Whether self.hash(password) == stored_hash (bool).
+            Deferred(bool): Whether self.hash(password) == stored_hash.
         """
+
+        def _do_validate_hash():
+            return bcrypt.checkpw(
+                password.encode('utf8') + self.hs.config.password_pepper,
+                stored_hash.encode('utf8')
+            )
+
         if stored_hash:
-            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
-                                 stored_hash.encode('utf8')) == stored_hash
+            return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
         else:
-            return False
+            return defer.succeed(False)
 
 
 class MacaroonGeneartor(object):
@@ -679,30 +914,3 @@ class MacaroonGeneartor(object):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
-
-
-class _AccountHandler(object):
-    """A proxy object that gets passed to password auth providers so they
-    can register new users etc if necessary.
-    """
-    def __init__(self, hs, check_user_exists):
-        self.hs = hs
-
-        self._check_user_exists = check_user_exists
-
-    def check_user_exists(self, user_id):
-        """Check if user exissts.
-
-        Returns:
-            Deferred(bool)
-        """
-        return self._check_user_exists(user_id)
-
-    def register(self, localpart):
-        """Registers a new user with given localpart
-
-        Returns:
-            Deferred: a 2-tuple of (user_id, access_token)
-        """
-        reg = self.hs.get_handlers().registration_handler
-        return reg.register(localpart=localpart)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
new file mode 100644
index 0000000000..b1d3814909
--- /dev/null
+++ b/synapse/handlers/deactivate_account.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from ._base import BaseHandler
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DeactivateAccountHandler(BaseHandler):
+    """Handler which deals with deactivating user accounts."""
+    def __init__(self, hs):
+        super(DeactivateAccountHandler, self).__init__(hs)
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
+
+    @defer.inlineCallbacks
+    def deactivate_account(self, user_id):
+        """Deactivate a user's account
+
+        Args:
+            user_id (str): ID of user to be deactivated
+
+        Returns:
+            Deferred
+        """
+        # FIXME: Theoretically there is a race here wherein user resets
+        # password using threepid.
+
+        # first delete any devices belonging to the user, which will also
+        # delete corresponding access tokens.
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+        # then delete any remaining access tokens which weren't associated with
+        # a device.
+        yield self._auth_handler.delete_access_tokens_for_user(user_id)
+
+        yield self.store.user_delete_threepids(user_id)
+        yield self.store.user_set_password_hash(user_id, None)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ca7137f315..f7457a7082 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,8 +14,11 @@
 # limitations under the License.
 from synapse.api import errors
 from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.metrics import measure_func
 from synapse.types import get_domain_from_id, RoomStreamToken
 from twisted.internet import defer
@@ -32,14 +35,17 @@ class DeviceHandler(BaseHandler):
 
         self.hs = hs
         self.state = hs.get_state_handler()
+        self._auth_handler = hs.get_auth_handler()
         self.federation_sender = hs.get_federation_sender()
-        self.federation = hs.get_replication_layer()
-        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
 
-        self.federation.register_edu_handler(
-            "m.device_list_update", self._incoming_device_list_update,
+        self._edu_updater = DeviceListEduUpdater(hs, self)
+
+        federation_registry = hs.get_federation_registry()
+
+        federation_registry.register_edu_handler(
+            "m.device_list_update", self._edu_updater.incoming_device_list_update,
         )
-        self.federation.register_query_handler(
+        federation_registry.register_query_handler(
             "user_devices", self.on_federation_query_user_devices,
         )
 
@@ -103,7 +109,7 @@ class DeviceHandler(BaseHandler):
         device_map = yield self.store.get_devices_by_user(user_id)
 
         ips = yield self.store.get_last_client_ip_by_device(
-            devices=((user_id, device_id) for device_id in device_map.keys())
+            user_id, device_id=None
         )
 
         devices = device_map.values()
@@ -130,7 +136,7 @@ class DeviceHandler(BaseHandler):
         except errors.StoreError:
             raise errors.NotFoundError
         ips = yield self.store.get_last_client_ip_by_device(
-            devices=((user_id, device_id),)
+            user_id, device_id,
         )
         _update_device_from_client_ips(device, ips)
         defer.returnValue(device)
@@ -149,16 +155,15 @@ class DeviceHandler(BaseHandler):
 
         try:
             yield self.store.delete_device(user_id, device_id)
-        except errors.StoreError, e:
+        except errors.StoreError as e:
             if e.code == 404:
                 # no match
                 pass
             else:
                 raise
 
-        yield self.store.user_delete_access_tokens(
+        yield self._auth_handler.delete_access_tokens_for_user(
             user_id, device_id=device_id,
-            delete_refresh_tokens=True,
         )
 
         yield self.store.delete_e2e_keys_by_device(
@@ -168,6 +173,57 @@ class DeviceHandler(BaseHandler):
         yield self.notify_device_update(user_id, [device_id])
 
     @defer.inlineCallbacks
+    def delete_all_devices_for_user(self, user_id, except_device_id=None):
+        """Delete all of the user's devices
+
+        Args:
+            user_id (str):
+            except_device_id (str|None): optional device id which should not
+                be deleted
+
+        Returns:
+            defer.Deferred:
+        """
+        device_map = yield self.store.get_devices_by_user(user_id)
+        device_ids = device_map.keys()
+        if except_device_id is not None:
+            device_ids = [d for d in device_ids if d != except_device_id]
+        yield self.delete_devices(user_id, device_ids)
+
+    @defer.inlineCallbacks
+    def delete_devices(self, user_id, device_ids):
+        """ Delete several devices
+
+        Args:
+            user_id (str):
+            device_ids (List[str]): The list of device IDs to delete
+
+        Returns:
+            defer.Deferred:
+        """
+
+        try:
+            yield self.store.delete_devices(user_id, device_ids)
+        except errors.StoreError as e:
+            if e.code == 404:
+                # no match
+                pass
+            else:
+                raise
+
+        # Delete access tokens and e2e keys for each device. Not optimised as it is not
+        # considered as part of a critical path.
+        for device_id in device_ids:
+            yield self._auth_handler.delete_access_tokens_for_user(
+                user_id, device_id=device_id,
+            )
+            yield self.store.delete_e2e_keys_by_device(
+                user_id=user_id, device_id=device_id
+            )
+
+        yield self.notify_device_update(user_id, device_ids)
+
+    @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
         """ Update the given device
 
@@ -187,7 +243,7 @@ class DeviceHandler(BaseHandler):
                 new_display_name=content.get("display_name")
             )
             yield self.notify_device_update(user_id, [device_id])
-        except errors.StoreError, e:
+        except errors.StoreError as e:
             if e.code == 404:
                 raise errors.NotFoundError()
             else:
@@ -212,8 +268,7 @@ class DeviceHandler(BaseHandler):
             user_id, device_ids, list(hosts)
         )
 
-        rooms = yield self.store.get_rooms_for_user(user_id)
-        room_ids = [r.room_id for r in rooms]
+        room_ids = yield self.store.get_rooms_for_user(user_id)
 
         yield self.notifier.on_new_event(
             "device_list_key", position, rooms=room_ids,
@@ -234,8 +289,9 @@ class DeviceHandler(BaseHandler):
             user_id (str)
             from_token (StreamToken)
         """
-        rooms = yield self.store.get_rooms_for_user(user_id)
-        room_ids = set(r.room_id for r in rooms)
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        room_ids = yield self.store.get_rooms_for_user(user_id)
 
         # First we check if any devices have changed
         changed = yield self.store.get_user_whose_devices_changed(
@@ -245,11 +301,30 @@ class DeviceHandler(BaseHandler):
         # Then work out if any users have since joined
         rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
 
+        member_events = yield self.store.get_membership_changes_for_user(
+            user_id, from_token.room_key, now_token.room_key
+        )
+        rooms_changed.update(event.room_id for event in member_events)
+
         stream_ordering = RoomStreamToken.parse_stream_token(
-            from_token.room_key).stream
+            from_token.room_key
+        ).stream
 
         possibly_changed = set(changed)
+        possibly_left = set()
         for room_id in rooms_changed:
+            current_state_ids = yield self.store.get_current_state_ids(room_id)
+
+            # The user may have left the room
+            # TODO: Check if they actually did or if we were just invited.
+            if room_id not in room_ids:
+                for key, event_id in current_state_ids.iteritems():
+                    etype, state_key = key
+                    if etype != EventTypes.Member:
+                        continue
+                    possibly_left.add(state_key)
+                continue
+
             # Fetch the current state at the time.
             try:
                 event_ids = yield self.store.get_forward_extremeties_for_room(
@@ -260,8 +335,6 @@ class DeviceHandler(BaseHandler):
                 # ordering: treat it the same as a new room
                 event_ids = []
 
-            current_state_ids = yield self.state.get_current_state_ids(room_id)
-
             # special-case for an empty prev state: include all members
             # in the changed list
             if not event_ids:
@@ -272,9 +345,25 @@ class DeviceHandler(BaseHandler):
                     possibly_changed.add(state_key)
                 continue
 
+            current_member_id = current_state_ids.get((EventTypes.Member, user_id))
+            if not current_member_id:
+                continue
+
             # mapping from event_id -> state_dict
             prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
 
+            # Check if we've joined the room? If so we just blindly add all the users to
+            # the "possibly changed" users.
+            for state_dict in prev_state_ids.itervalues():
+                member_event = state_dict.get((EventTypes.Member, user_id), None)
+                if not member_event or member_event != current_member_id:
+                    for key, event_id in current_state_ids.iteritems():
+                        etype, state_key = key
+                        if etype != EventTypes.Member:
+                            continue
+                        possibly_changed.add(state_key)
+                    break
+
             # If there has been any change in membership, include them in the
             # possibly changed list. We'll check if they are joined below,
             # and we're not toooo worried about spuriously adding users.
@@ -285,94 +374,208 @@ class DeviceHandler(BaseHandler):
 
                 # check if this member has changed since any of the extremities
                 # at the stream_ordering, and add them to the list if so.
-                for state_dict in prev_state_ids.values():
+                for state_dict in prev_state_ids.itervalues():
                     prev_event_id = state_dict.get(key, None)
                     if not prev_event_id or prev_event_id != event_id:
-                        possibly_changed.add(state_key)
+                        if state_key != user_id:
+                            possibly_changed.add(state_key)
                         break
 
-        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
-            user_id
-        )
+        if possibly_changed or possibly_left:
+            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+                user_id
+            )
 
-        # Take the intersection of the users whose devices may have changed
-        # and those that actually still share a room with the user
-        defer.returnValue(users_who_share_room & possibly_changed)
+            # Take the intersection of the users whose devices may have changed
+            # and those that actually still share a room with the user
+            possibly_joined = possibly_changed & users_who_share_room
+            possibly_left = (possibly_changed | possibly_left) - users_who_share_room
+        else:
+            possibly_joined = []
+            possibly_left = []
+
+        defer.returnValue({
+            "changed": list(possibly_joined),
+            "left": list(possibly_left),
+        })
 
-    @measure_func("_incoming_device_list_update")
     @defer.inlineCallbacks
-    def _incoming_device_list_update(self, origin, edu_content):
-        user_id = edu_content["user_id"]
-        device_id = edu_content["device_id"]
-        stream_id = edu_content["stream_id"]
-        prev_ids = edu_content.get("prev_id", [])
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
+    @defer.inlineCallbacks
+    def user_left_room(self, user, room_id):
+        user_id = user.to_string()
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+        if not room_ids:
+            # We no longer share rooms with this user, so we'll no longer
+            # receive device updates. Mark this in DB.
+            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+
+def _update_device_from_client_ips(device, client_ips):
+    ip = client_ips.get((device["user_id"], device["device_id"]), {})
+    device.update({
+        "last_seen_ts": ip.get("last_seen"),
+        "last_seen_ip": ip.get("ip"),
+    })
+
+
+class DeviceListEduUpdater(object):
+    "Handles incoming device list updates from federation and updates the DB"
+
+    def __init__(self, hs, device_handler):
+        self.store = hs.get_datastore()
+        self.federation = hs.get_federation_client()
+        self.clock = hs.get_clock()
+        self.device_handler = device_handler
+
+        self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+
+        # user_id -> list of updates waiting to be handled.
+        self._pending_updates = {}
+
+        # Recently seen stream ids. We don't bother keeping these in the DB,
+        # but they're useful to have them about to reduce the number of spurious
+        # resyncs.
+        self._seen_updates = ExpiringCache(
+            cache_name="device_update_edu",
+            clock=self.clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+            iterable=True,
+        )
+
+    @defer.inlineCallbacks
+    def incoming_device_list_update(self, origin, edu_content):
+        """Called on incoming device list update from federation. Responsible
+        for parsing the EDU and adding to pending updates list.
+        """
+
+        user_id = edu_content.pop("user_id")
+        device_id = edu_content.pop("device_id")
+        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
+        prev_ids = edu_content.pop("prev_id", [])
+        prev_ids = [str(p) for p in prev_ids]   # They may come as ints
 
         if get_domain_from_id(user_id) != origin:
             # TODO: Raise?
             logger.warning("Got device list update edu for %r from %r", user_id, origin)
             return
 
-        rooms = yield self.store.get_rooms_for_user(user_id)
-        if not rooms:
+        room_ids = yield self.store.get_rooms_for_user(user_id)
+        if not room_ids:
             # We don't share any rooms with this user. Ignore update, as we
             # probably won't get any further updates.
             return
 
-        with (yield self._remote_edue_linearizer.queue(user_id)):
-            # If the prev id matches whats in our cache table, then we don't need
-            # to resync the users device list, otherwise we do.
-            resync = True
-            if len(prev_ids) == 1:
-                extremity = yield self.store.get_device_list_last_stream_id_for_remote(
-                    user_id
-                )
-                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
-                if str(extremity) == str(prev_ids[0]):
-                    resync = False
+        self._pending_updates.setdefault(user_id, []).append(
+            (device_id, stream_id, prev_ids, edu_content)
+        )
+
+        yield self._handle_device_updates(user_id)
+
+    @measure_func("_incoming_device_list_update")
+    @defer.inlineCallbacks
+    def _handle_device_updates(self, user_id):
+        "Actually handle pending updates."
+
+        with (yield self._remote_edu_linearizer.queue(user_id)):
+            pending_updates = self._pending_updates.pop(user_id, [])
+            if not pending_updates:
+                # This can happen since we batch updates
+                return
+
+            # Given a list of updates we check if we need to resync. This
+            # happens if we've missed updates.
+            resync = yield self._need_to_do_resync(user_id, pending_updates)
 
             if resync:
                 # Fetch all devices for the user.
-                result = yield self.federation.query_user_devices(origin, user_id)
+                origin = get_domain_from_id(user_id)
+                try:
+                    result = yield self.federation.query_user_devices(origin, user_id)
+                except NotRetryingDestination:
+                    # TODO: Remember that we are now out of sync and try again
+                    # later
+                    logger.warn(
+                        "Failed to handle device list update for %s,"
+                        " we're not retrying the remote",
+                        user_id,
+                    )
+                    # We abort on exceptions rather than accepting the update
+                    # as otherwise synapse will 'forget' that its device list
+                    # is out of date. If we bail then we will retry the resync
+                    # next time we get a device list update for this user_id.
+                    # This makes it more likely that the device lists will
+                    # eventually become consistent.
+                    return
+                except FederationDeniedError as e:
+                    logger.info(e)
+                    return
+                except Exception:
+                    # TODO: Remember that we are now out of sync and try again
+                    # later
+                    logger.exception(
+                        "Failed to handle device list update for %s", user_id
+                    )
+                    return
+
                 stream_id = result["stream_id"]
                 devices = result["devices"]
                 yield self.store.update_remote_device_list_cache(
                     user_id, devices, stream_id,
                 )
                 device_ids = [device["device_id"] for device in devices]
-                yield self.notify_device_update(user_id, device_ids)
+                yield self.device_handler.notify_device_update(user_id, device_ids)
             else:
                 # Simply update the single device, since we know that is the only
                 # change (becuase of the single prev_id matching the current cache)
-                content = dict(edu_content)
-                for key in ("user_id", "device_id", "stream_id", "prev_ids"):
-                    content.pop(key, None)
-                yield self.store.update_remote_device_list_cache_entry(
-                    user_id, device_id, content, stream_id,
+                for device_id, stream_id, prev_ids, content in pending_updates:
+                    yield self.store.update_remote_device_list_cache_entry(
+                        user_id, device_id, content, stream_id,
+                    )
+
+                yield self.device_handler.notify_device_update(
+                    user_id, [device_id for device_id, _, _, _ in pending_updates]
                 )
-                yield self.notify_device_update(user_id, [device_id])
 
-    @defer.inlineCallbacks
-    def on_federation_query_user_devices(self, user_id):
-        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
-        defer.returnValue({
-            "user_id": user_id,
-            "stream_id": stream_id,
-            "devices": devices,
-        })
+            self._seen_updates.setdefault(user_id, set()).update(
+                stream_id for _, stream_id, _, _ in pending_updates
+            )
 
     @defer.inlineCallbacks
-    def user_left_room(self, user, room_id):
-        user_id = user.to_string()
-        rooms = yield self.store.get_rooms_for_user(user_id)
-        if not rooms:
-            # We no longer share rooms with this user, so we'll no longer
-            # receive device updates. Mark this in DB.
-            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+    def _need_to_do_resync(self, user_id, updates):
+        """Given a list of updates for a user figure out if we need to do a full
+        resync, or whether we have enough data that we can just apply the delta.
+        """
+        seen_updates = self._seen_updates.get(user_id, set())
 
+        extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+            user_id
+        )
 
-def _update_device_from_client_ips(device, client_ips):
-    ip = client_ips.get((device["user_id"], device["device_id"]), {})
-    device.update({
-        "last_seen_ts": ip.get("last_seen"),
-        "last_seen_ip": ip.get("ip"),
-    })
+        stream_id_in_updates = set()  # stream_ids in updates list
+        for _, stream_id, prev_ids, _ in updates:
+            if not prev_ids:
+                # We always do a resync if there are no previous IDs
+                defer.returnValue(True)
+
+            for prev_id in prev_ids:
+                if prev_id == extremity:
+                    continue
+                elif prev_id in seen_updates:
+                    continue
+                elif prev_id in stream_id_in_updates:
+                    continue
+                else:
+                    defer.returnValue(True)
+
+            stream_id_in_updates.add(stream_id)
+
+        defer.returnValue(False)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index f7fad15c62..f147a20b73 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.types import get_domain_from_id
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id, UserID
 from synapse.util.stringutils import random_string
 
 
@@ -33,10 +34,10 @@ class DeviceMessageHandler(object):
         """
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.federation = hs.get_federation_sender()
 
-        hs.get_replication_layer().register_edu_handler(
+        hs.get_federation_registry().register_edu_handler(
             "m.direct_to_device", self.on_direct_to_device_edu
         )
 
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
         message_type = content["type"]
         message_id = content["message_id"]
         for user_id, by_device in content["messages"].items():
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
+                logger.warning("Request for keys for non-local user %s",
+                               user_id)
+                raise SynapseError(400, "Not a user here")
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
         local_messages = {}
         remote_messages = {}
         for user_id, by_device in messages.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 messages_by_device = {
                     device_id: {
                         "content": message_content,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1b5317edf5..c5b6e75e03 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -34,12 +34,15 @@ class DirectoryHandler(BaseHandler):
 
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
-        self.federation = hs.get_replication_layer()
-        self.federation.register_query_handler(
+        self.federation = hs.get_federation_client()
+        hs.get_federation_registry().register_query_handler(
             "directory", self.on_directory_query
         )
 
+        self.spam_checker = hs.get_spam_checker()
+
     @defer.inlineCallbacks
     def _create_association(self, room_alias, room_id, servers=None, creator=None):
         # general association creation for both human users and app services
@@ -73,6 +76,11 @@ class DirectoryHandler(BaseHandler):
         # association creation for human users
         # TODO(erikj): Do user auth.
 
+        if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+            raise SynapseError(
+                403, "This user is not permitted to create this alias",
+            )
+
         can_create = yield self.can_modify_alias(
             room_alias,
             user_id=user_id
@@ -175,6 +183,7 @@ class DirectoryHandler(BaseHandler):
                         "room_alias": room_alias.to_string(),
                     },
                     retry_on_dns_fail=False,
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 logging.warn("Error retrieving alias")
@@ -241,8 +250,7 @@ class DirectoryHandler(BaseHandler):
     def send_room_alias_update_event(self, requester, user_id, room_id):
         aliases = yield self.store.get_aliases_for_room(room_id)
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Aliases,
@@ -264,8 +272,7 @@ class DirectoryHandler(BaseHandler):
         if not alias_event or alias_event.content.get("alias", "") != alias_str:
             return
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.CanonicalAlias,
@@ -326,6 +333,14 @@ class DirectoryHandler(BaseHandler):
         room_id (str)
         visibility (str): "public" or "private"
         """
+        if not self.spam_checker.user_may_publish_room(
+            requester.user.to_string(), room_id
+        ):
+            raise AuthError(
+                403,
+                "This user is not permitted to publish rooms to the room list"
+            )
+
         if requester.is_guest:
             raise AuthError(403, "Guests cannot edit the published room list")
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index e40495d1ab..25aec624af 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,16 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import ujson as json
+import simplejson as json
 import logging
 
 from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+from synapse.api.errors import (
+    SynapseError, CodeMessageException, FederationDeniedError,
+)
+from synapse.types import get_domain_from_id, UserID
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
 
@@ -30,15 +33,15 @@ logger = logging.getLogger(__name__)
 class E2eKeysHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        self.federation = hs.get_replication_layer()
+        self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
-        self.is_mine_id = hs.is_mine_id
+        self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
         # doesn't really work as part of the generic query API, because the
         # query request requires an object POST, but we abuse the
         # "query handler" interface.
-        self.federation.register_query_handler(
+        hs.get_federation_registry().register_query_handler(
             "client_keys", self.on_federation_query_client_keys
         )
 
@@ -70,7 +73,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_ids in device_keys_query.items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 local_query[user_id] = device_ids
             else:
                 remote_queries[user_id] = device_ids
@@ -121,38 +125,23 @@ class E2eKeysHandler(object):
         def do_remote_query(destination):
             destination_query = remote_queries_not_in_cache[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.query_client_keys(
+                    destination,
+                    {"device_keys": destination_query},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.query_client_keys(
-                        destination,
-                        {"device_keys": destination_query},
-                        timeout=timeout
-                    )
 
                 for user_id, keys in remote_result["device_keys"].items():
                     if user_id in destination_query:
                         results[user_id] = keys
 
-            except CodeMessageException as e:
-                failures[destination] = {
-                    "status": e.code, "message": e.message
-                }
-            except NotRetryingDestination as e:
-                failures[destination] = {
-                    "status": 503, "message": "Not ready for retry",
-                }
             except Exception as e:
-                # include ConnectionRefused and other errors
-                failures[destination] = {
-                    "status": 503, "message": e.message
-                }
+                failures[destination] = _exception_to_failure(e)
 
-        yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(do_remote_query)(destination)
+        yield make_deferred_yieldable(defer.gatherResults([
+            run_in_background(do_remote_query, destination)
             for destination in remote_queries_not_in_cache
-        ]))
+        ], consumeErrors=True))
 
         defer.returnValue({
             "device_keys": results, "failures": failures,
@@ -174,7 +163,8 @@ class E2eKeysHandler(object):
 
         result_dict = {}
         for user_id, device_ids in query.items():
-            if not self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if not self.is_mine(UserID.from_string(user_id)):
                 logger.warning("Request for keys for non-local user %s",
                                user_id)
                 raise SynapseError(400, "Not a user here")
@@ -217,7 +207,8 @@ class E2eKeysHandler(object):
         remote_queries = {}
 
         for user_id, device_keys in query.get("one_time_keys", {}).items():
-            if self.is_mine_id(user_id):
+            # we use UserID.from_string to catch invalid user ids
+            if self.is_mine(UserID.from_string(user_id)):
                 for device_id, algorithm in device_keys.items():
                     local_query.append((user_id, device_id, algorithm))
             else:
@@ -239,36 +230,31 @@ class E2eKeysHandler(object):
         def claim_client_keys(destination):
             device_keys = remote_queries[destination]
             try:
-                limiter = yield get_retry_limiter(
-                    destination, self.clock, self.store
+                remote_result = yield self.federation.claim_client_keys(
+                    destination,
+                    {"one_time_keys": device_keys},
+                    timeout=timeout
                 )
-                with limiter:
-                    remote_result = yield self.federation.claim_client_keys(
-                        destination,
-                        {"one_time_keys": device_keys},
-                        timeout=timeout
-                    )
-                    for user_id, keys in remote_result["one_time_keys"].items():
-                        if user_id in device_keys:
-                            json_result[user_id] = keys
-            except CodeMessageException as e:
-                failures[destination] = {
-                    "status": e.code, "message": e.message
-                }
-            except NotRetryingDestination as e:
-                failures[destination] = {
-                    "status": 503, "message": "Not ready for retry",
-                }
+                for user_id, keys in remote_result["one_time_keys"].items():
+                    if user_id in device_keys:
+                        json_result[user_id] = keys
             except Exception as e:
-                # include ConnectionRefused and other errors
-                failures[destination] = {
-                    "status": 503, "message": e.message
-                }
+                failures[destination] = _exception_to_failure(e)
 
-        yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(claim_client_keys)(destination)
+        yield make_deferred_yieldable(defer.gatherResults([
+            run_in_background(claim_client_keys, destination)
             for destination in remote_queries
-        ]))
+        ], consumeErrors=True))
+
+        logger.info(
+            "Claimed one-time-keys: %s",
+            ",".join((
+                "%s for %s:%s" % (key_id, user_id, device_id)
+                for user_id, user_keys in json_result.iteritems()
+                for device_id, device_keys in user_keys.iteritems()
+                for key_id, _ in device_keys.iteritems()
+            )),
+        )
 
         defer.returnValue({
             "one_time_keys": json_result,
@@ -296,19 +282,8 @@ class E2eKeysHandler(object):
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
-            logger.info(
-                "Adding %d one_time_keys for device %r for user %r at %d",
-                len(one_time_keys), device_id, user_id, time_now
-            )
-            key_list = []
-            for key_id, key_json in one_time_keys.items():
-                algorithm, key_id = key_id.split(":")
-                key_list.append((
-                    algorithm, key_id, encode_canonical_json(key_json)
-                ))
-
-            yield self.store.add_e2e_one_time_keys(
-                user_id, device_id, time_now, key_list
+            yield self._upload_one_time_keys_for_user(
+                user_id, device_id, time_now, one_time_keys,
             )
 
         # the device should have been registered already, but it may have been
@@ -316,8 +291,88 @@ class E2eKeysHandler(object):
         # old access_token without an associated device_id. Either way, we
         # need to double-check the device is registered to avoid ending up with
         # keys without a corresponding device.
-        self.device_handler.check_device_registered(user_id, device_id)
+        yield self.device_handler.check_device_registered(user_id, device_id)
 
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
         defer.returnValue({"one_time_key_counts": result})
+
+    @defer.inlineCallbacks
+    def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+                                       one_time_keys):
+        logger.info(
+            "Adding one_time_keys %r for device %r for user %r at %d",
+            one_time_keys.keys(), device_id, user_id, time_now,
+        )
+
+        # make a list of (alg, id, key) tuples
+        key_list = []
+        for key_id, key_obj in one_time_keys.items():
+            algorithm, key_id = key_id.split(":")
+            key_list.append((
+                algorithm, key_id, key_obj
+            ))
+
+        # First we check if we have already persisted any of the keys.
+        existing_key_map = yield self.store.get_e2e_one_time_keys(
+            user_id, device_id, [k_id for _, k_id, _ in key_list]
+        )
+
+        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
+        for algorithm, key_id, key in key_list:
+            ex_json = existing_key_map.get((algorithm, key_id), None)
+            if ex_json:
+                if not _one_time_keys_match(ex_json, key):
+                    raise SynapseError(
+                        400,
+                        ("One time key %s:%s already exists. "
+                         "Old key: %s; new key: %r") %
+                        (algorithm, key_id, ex_json, key)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+        yield self.store.add_e2e_one_time_keys(
+            user_id, device_id, time_now, new_keys
+        )
+
+
+def _exception_to_failure(e):
+    if isinstance(e, CodeMessageException):
+        return {
+            "status": e.code, "message": e.message,
+        }
+
+    if isinstance(e, NotRetryingDestination):
+        return {
+            "status": 503, "message": "Not ready for retry",
+        }
+
+    if isinstance(e, FederationDeniedError):
+        return {
+            "status": 403, "message": "Federation Denied",
+        }
+
+    # include ConnectionRefused and other errors
+    #
+    # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
+    # give a string for e.message, which simplejson then fails to serialize.
+    return {
+        "status": 503, "message": str(e.message),
+    }
+
+
+def _one_time_keys_match(old_key_json, new_key):
+    old_key = json.loads(old_key_json)
+
+    # if either is a string rather than an object, they must match exactly
+    if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+        return old_key == new_key
+
+    # otherwise, we strip off the 'signatures' if any, because it's legitimate
+    # for different upload attempts to have different signatures.
+    old_key.pop("signatures", None)
+    new_key_copy = dict(new_key)
+    new_key_copy.pop("signatures", None)
+
+    return old_key == new_key_copy
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 996bfd0e23..f39233d846 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,24 +15,30 @@
 # limitations under the License.
 
 """Contains handlers for federation events."""
+
+import itertools
+import logging
+import sys
+
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
+import six
+from six.moves import http_client
+from twisted.internet import defer
 from unpaddedbase64 import decode_base64
 
 from ._base import BaseHandler
 
 from synapse.api.errors import (
     AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+    FederationDeniedError,
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
-)
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
+from synapse.util.async import run_on_reactor, Linearizer
 from synapse.util.frozenutils import unfreeze
 from synapse.crypto.event_signing import (
     compute_event_signature, add_hashes_and_signatures,
@@ -42,13 +49,8 @@ from synapse.events.utils import prune_event
 
 from synapse.util.retryutils import NotRetryingDestination
 
-from synapse.push.action_generator import ActionGenerator
 from synapse.util.distributor import user_joined_room
 
-from twisted.internet import defer
-
-import itertools
-import logging
 
 logger = logging.getLogger(__name__)
 
@@ -70,38 +72,268 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
-        self.replication_layer = hs.get_replication_layer()
+        self.replication_layer = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
-
-        self.replication_layer.set_handler(self)
+        self.action_generator = hs.get_action_generator()
+        self.is_mine_id = hs.is_mine_id
+        self.pusher_pool = hs.get_pusherpool()
+        self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
+        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
-    @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
-        """ Called by the ReplicationLayer when we have a new pdu. We need to
-        do auth checks and put it through the StateHandler.
+    @log_function
+    def on_receive_pdu(self, origin, pdu, get_missing=True):
+        """ Process a PDU received via a federation /send/ transaction, or
+        via backfill of missing prev_events
+
+        Args:
+            origin (str): server which initiated the /send/ transaction. Will
+                be used to fetch missing events or state.
+            pdu (FrozenEvent): received PDU
+            get_missing (bool): True if we should fetch missing prev_events
 
-        auth_chain and state are None if we already have the necessary state
-        and prev_events in the db
+        Returns (Deferred): completes with None
         """
-        event = pdu
 
-        logger.debug("Got event: %s", event.event_id)
+        # We reprocess pdus when we have seen them only as outliers
+        existing = yield self.get_persisted_pdu(
+            origin, pdu.event_id, do_auth=False
+        )
+
+        # FIXME: Currently we fetch an event again when we already have it
+        # if it has been marked as an outlier.
+
+        already_seen = (
+            existing and (
+                not existing.internal_metadata.is_outlier()
+                or pdu.internal_metadata.is_outlier()
+            )
+        )
+        if already_seen:
+            logger.debug("Already seen pdu %s", pdu.event_id)
+            return
+
+        # do some initial sanity-checking of the event. In particular, make
+        # sure it doesn't have hundreds of prev_events or auth_events, which
+        # could cause a huge state resolution or cascade of event fetches.
+        try:
+            self._sanity_check_event(pdu)
+        except SynapseError as err:
+            raise FederationError(
+                "ERROR",
+                err.code,
+                err.msg,
+                affected=pdu.event_id,
+            )
 
         # If we are currently in the process of joining this room, then we
         # queue up events for later processing.
-        if event.room_id in self.room_queues:
-            self.room_queues[event.room_id].append((pdu, origin))
+        if pdu.room_id in self.room_queues:
+            logger.info("Ignoring PDU %s for room %s from %s for now; join "
+                        "in progress", pdu.event_id, pdu.room_id, origin)
+            self.room_queues[pdu.room_id].append((pdu, origin))
             return
 
-        logger.debug("Processing event: %s", event.event_id)
+        # If we're no longer in the room just ditch the event entirely. This
+        # is probably an old server that has come back and thinks we're still
+        # in the room (or we've been rejoined to the room by a state reset).
+        #
+        # If we were never in the room then maybe our database got vaped and
+        # we should check if we *are* in fact in the room. If we are then we
+        # can magically rejoin the room.
+        is_in_room = yield self.auth.check_host_in_room(
+            pdu.room_id,
+            self.server_name
+        )
+        if not is_in_room:
+            was_in_room = yield self.store.was_host_joined(
+                pdu.room_id, self.server_name,
+            )
+            if was_in_room:
+                logger.info(
+                    "Ignoring PDU %s for room %s from %s as we've left the room!",
+                    pdu.event_id, pdu.room_id, origin,
+                )
+                return
+
+        state = None
+
+        auth_chain = []
+
+        fetch_state = False
+
+        # Get missing pdus if necessary.
+        if not pdu.internal_metadata.is_outlier():
+            # We only backfill backwards to the min depth.
+            min_depth = yield self.get_min_depth_for_context(
+                pdu.room_id
+            )
+
+            logger.debug(
+                "_handle_new_pdu min_depth for %s: %d",
+                pdu.room_id, min_depth
+            )
+
+            prevs = {e_id for e_id, _ in pdu.prev_events}
+            seen = yield self.store.have_seen_events(prevs)
+
+            if min_depth and pdu.depth < min_depth:
+                # This is so that we don't notify the user about this
+                # message, to work around the fact that some events will
+                # reference really really old events we really don't want to
+                # send to the clients.
+                pdu.internal_metadata.outlier = True
+            elif min_depth and pdu.depth > min_depth:
+                if get_missing and prevs - seen:
+                    # If we're missing stuff, ensure we only fetch stuff one
+                    # at a time.
+                    logger.info(
+                        "Acquiring lock for room %r to fetch %d missing events: %r...",
+                        pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
+                    )
+                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                        logger.info(
+                            "Acquired lock for room %r to fetch %d missing events",
+                            pdu.room_id, len(prevs - seen),
+                        )
+
+                        yield self._get_missing_events_for_pdu(
+                            origin, pdu, prevs, min_depth
+                        )
+
+                        # Update the set of things we've seen after trying to
+                        # fetch the missing stuff
+                        seen = yield self.store.have_seen_events(prevs)
+
+                        if not prevs - seen:
+                            logger.info(
+                                "Found all missing prev events for %s", pdu.event_id
+                            )
+                elif prevs - seen:
+                    logger.info(
+                        "Not fetching %d missing events for room %r,event %s: %r...",
+                        len(prevs - seen), pdu.room_id, pdu.event_id,
+                        list(prevs - seen)[:5],
+                    )
+
+            if prevs - seen:
+                logger.info(
+                    "Still missing %d events for room %r: %r...",
+                    len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
+                )
+                fetch_state = True
 
-        logger.debug("Event: %s", event)
+        if fetch_state:
+            # We need to get the state at this event, since we haven't
+            # processed all the prev events.
+            logger.debug(
+                "_handle_new_pdu getting state for %s",
+                pdu.room_id
+            )
+            try:
+                state, auth_chain = yield self.replication_layer.get_state_for_room(
+                    origin, pdu.room_id, pdu.event_id,
+                )
+            except Exception:
+                logger.exception("Failed to get state for event: %s", pdu.event_id)
+
+        yield self._process_received_pdu(
+            origin,
+            pdu,
+            state=state,
+            auth_chain=auth_chain,
+        )
+
+    @defer.inlineCallbacks
+    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+        """
+        Args:
+            origin (str): Origin of the pdu. Will be called to get the missing events
+            pdu: received pdu
+            prevs (set(str)): List of event ids which we are missing
+            min_depth (int): Minimum depth of events to return.
+        """
+        # We recalculate seen, since it may have changed.
+        seen = yield self.store.have_seen_events(prevs)
+
+        if not prevs - seen:
+            return
+
+        latest = yield self.store.get_latest_event_ids_in_room(
+            pdu.room_id
+        )
+
+        # We add the prev events that we have seen to the latest
+        # list to ensure the remote server doesn't give them to us
+        latest = set(latest)
+        latest |= seen
+
+        logger.info(
+            "Missing %d events for room %r pdu %s: %r...",
+            len(prevs - seen), pdu.room_id, pdu.event_id, list(prevs - seen)[:5]
+        )
+
+        # XXX: we set timeout to 10s to help workaround
+        # https://github.com/matrix-org/synapse/issues/1733.
+        # The reason is to avoid holding the linearizer lock
+        # whilst processing inbound /send transactions, causing
+        # FDs to stack up and block other inbound transactions
+        # which empirically can currently take up to 30 minutes.
+        #
+        # N.B. this explicitly disables retry attempts.
+        #
+        # N.B. this also increases our chances of falling back to
+        # fetching fresh state for the room if the missing event
+        # can't be found, which slightly reduces our security.
+        # it may also increase our DAG extremity count for the room,
+        # causing additional state resolution?  See #1760.
+        # However, fetching state doesn't hold the linearizer lock
+        # apparently.
+        #
+        # see https://github.com/matrix-org/synapse/pull/1744
+
+        missing_events = yield self.replication_layer.get_missing_events(
+            origin,
+            pdu.room_id,
+            earliest_events_ids=list(latest),
+            latest_events=[pdu],
+            limit=10,
+            min_depth=min_depth,
+            timeout=10000,
+        )
+
+        logger.info(
+            "Got %d events: %r...",
+            len(missing_events), [e.event_id for e in missing_events[:5]]
+        )
+
+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        missing_events.sort(key=lambda x: x.depth)
+
+        for e in missing_events:
+            logger.info("Handling found event %s", e.event_id)
+            yield self.on_receive_pdu(
+                origin,
+                e,
+                get_missing=False
+            )
+
+    @log_function
+    @defer.inlineCallbacks
+    def _process_received_pdu(self, origin, pdu, state, auth_chain):
+        """ Called when we have a new pdu. We need to do auth checks and put it
+        through the StateHandler.
+        """
+        event = pdu
+
+        logger.debug("Processing event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
         # in the room work
@@ -140,9 +372,7 @@ class FederationHandler(BaseHandler):
             if auth_chain:
                 event_ids |= {e.event_id for e in auth_chain}
 
-            seen_ids = set(
-                (yield self.store.have_events(event_ids)).keys()
-            )
+            seen_ids = yield self.store.have_seen_events(event_ids)
 
             if state and auth_chain is not None:
                 # If we have any state or auth_chain given to us by the replication
@@ -181,13 +411,6 @@ class FederationHandler(BaseHandler):
                     affected=event.event_id,
                 )
 
-        # if we're receiving valid events from an origin,
-        # it's probably a good idea to mark it as not in retry-state
-        # for sending (although this is a bit of a leap)
-        retry_timings = yield self.store.get_destination_retry_timings(origin)
-        if retry_timings and retry_timings["retry_last_ts"]:
-            self.store.set_destination_retry_timings(origin, 0, 0)
-
         room = yield self.store.get_room(event.room_id)
 
         if not room:
@@ -206,11 +429,10 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=extra_users
+        )
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -249,7 +471,7 @@ class FederationHandler(BaseHandler):
         def check_match(id):
             try:
                 return server_name == get_domain_from_id(id)
-            except:
+            except Exception:
                 return False
 
         # Parses mapping `event_id -> (type, state_key) -> state event_id`
@@ -287,7 +509,7 @@ class FederationHandler(BaseHandler):
                             continue
                         try:
                             domain = get_domain_from_id(ev.state_key)
-                        except:
+                        except Exception:
                             continue
 
                         if domain != server_name:
@@ -314,9 +536,16 @@ class FederationHandler(BaseHandler):
     def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
-        This will attempt to get more events from the remote. This may return
-        be successfull and still return no events if the other side has no new
-        events to offer.
+        This will attempt to get more events from the remote. If the other side
+        has no new events to offer, this will return an empty list.
+
+        As the events are received, we check their signatures, and also do some
+        sanity-checking on them. If any of the backfilled events are invalid,
+        this method throws a SynapseError.
+
+        TODO: make this more useful to distinguish failures of the remote
+        server from invalid events (there is probably no point in trying to
+        re-fetch invalid events from every other HS in the room.)
         """
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
@@ -328,6 +557,16 @@ class FederationHandler(BaseHandler):
             extremities=extremities,
         )
 
+        # ideally we'd sanity check the events here for excess prev_events etc,
+        # but it's hard to reject events at this point without completely
+        # breaking backfill in the same way that it is currently broken by
+        # events whose signature we cannot verify (#3121).
+        #
+        # So for now we accept the events anyway. #3124 tracks this.
+        #
+        # for ev in events:
+        #     self._sanity_check_event(ev)
+
         # Don't bother processing events we already have.
         seen_events = yield self.store.have_events_in_timeline(
             set(e.event_id for e in events)
@@ -398,9 +637,10 @@ class FederationHandler(BaseHandler):
                     missing_auth - failed_to_fetch
                 )
 
-                results = yield preserve_context_over_deferred(defer.gatherResults(
+                results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
-                        preserve_fn(self.replication_layer.get_pdu)(
+                        logcontext.run_in_background(
+                            self.replication_layer.get_pdu,
                             [dest],
                             event_id,
                             outlier=True,
@@ -420,7 +660,7 @@ class FederationHandler(BaseHandler):
 
                 failed_to_fetch = missing_auth - set(auth_events)
 
-        seen_events = yield self.store.have_events(
+        seen_events = yield self.store.have_seen_events(
             set(auth_events.keys()) | set(state_events.keys())
         )
 
@@ -526,7 +766,7 @@ class FederationHandler(BaseHandler):
                         joined_domains[dom] = min(d, old_d)
                     else:
                         joined_domains[dom] = d
-                except:
+                except Exception:
                     pass
 
             return sorted(joined_domains.items(), key=lambda d: d[1])
@@ -570,6 +810,9 @@ class FederationHandler(BaseHandler):
                 except NotRetryingDestination as e:
                     logger.info(e.message)
                     continue
+                except FederationDeniedError as e:
+                    logger.info(e)
+                    continue
                 except Exception as e:
                     logger.exception(
                         "Failed to backfill from %s because %s",
@@ -592,10 +835,13 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
-        states = yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
-            for e in event_ids
-        ]))
+        resolve = logcontext.preserve_fn(
+            self.state_handler.resolve_state_groups_for_events
+        )
+        states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+            [resolve(room_id, [e]) for e in event_ids],
+            consumeErrors=True,
+        ))
         states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
@@ -624,6 +870,38 @@ class FederationHandler(BaseHandler):
 
         defer.returnValue(False)
 
+    def _sanity_check_event(self, ev):
+        """
+        Do some early sanity checks of a received event
+
+        In particular, checks it doesn't have an excessive number of
+        prev_events or auth_events, which could cause a huge state resolution
+        or cascade of event fetches.
+
+        Args:
+            ev (synapse.events.EventBase): event to be checked
+
+        Returns: None
+
+        Raises:
+            SynapseError if the event does not pass muster
+        """
+        if len(ev.prev_events) > 20:
+            logger.warn("Rejecting event %s which has %i prev_events",
+                        ev.event_id, len(ev.prev_events))
+            raise SynapseError(
+                http_client.BAD_REQUEST,
+                "Too many prev_events",
+            )
+
+        if len(ev.auth_events) > 10:
+            logger.warn("Rejecting event %s which has %i auth_events",
+                        ev.event_id, len(ev.auth_events))
+            raise SynapseError(
+                http_client.BAD_REQUEST,
+                "Too many auth_events",
+            )
+
     @defer.inlineCallbacks
     def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
@@ -641,7 +919,11 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def on_event_auth(self, event_id):
-        auth = yield self.store.get_auth_chain([event_id])
+        event = yield self.store.get_event(event_id)
+        auth = yield self.store.get_auth_chain(
+            [auth_id for auth_id, _ in event.auth_events],
+            include_given=True
+        )
 
         for event in auth:
             event.signatures.update(
@@ -670,8 +952,6 @@ class FederationHandler(BaseHandler):
         """
         logger.debug("Joining %s to %s", joinee, room_id)
 
-        yield self.store.clean_room_for_join(room_id)
-
         origin, event = yield self._make_and_verify_event(
             target_hosts,
             room_id,
@@ -680,7 +960,15 @@ class FederationHandler(BaseHandler):
             content,
         )
 
+        # This shouldn't happen, because the RoomMemberHandler has a
+        # linearizer lock which only allows one operation per user per room
+        # at a time - so this is just paranoia.
+        assert (room_id not in self.room_queues)
+
         self.room_queues[room_id] = []
+
+        yield self.store.clean_room_for_join(room_id)
+
         handled_events = set()
 
         try:
@@ -714,7 +1002,7 @@ class FederationHandler(BaseHandler):
                     room_creator_user_id="",
                     is_public=False
                 )
-            except:
+            except Exception:
                 # FIXME
                 pass
 
@@ -722,29 +1010,45 @@ class FederationHandler(BaseHandler):
                 origin, auth_chain, state, event
             )
 
-            with PreserveLoggingContext():
-                self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id,
-                    extra_users=[joinee]
-                )
+            self.notifier.on_new_room_event(
+                event, event_stream_id, max_stream_id,
+                extra_users=[joinee]
+            )
 
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
             room_queue = self.room_queues[room_id]
             del self.room_queues[room_id]
 
-            for p, origin in room_queue:
-                if p.event_id in handled_events:
-                    continue
+            # we don't need to wait for the queued events to be processed -
+            # it's just a best-effort thing at this point. We do want to do
+            # them roughly in order, though, otherwise we'll end up making
+            # lots of requests for missing prev_events which we do actually
+            # have. Hence we fire off the deferred, but don't wait for it.
 
-                try:
-                    self.on_receive_pdu(origin, p)
-                except:
-                    logger.exception("Couldn't handle pdu")
+            logcontext.run_in_background(self._handle_queued_pdus, room_queue)
 
         defer.returnValue(True)
 
     @defer.inlineCallbacks
+    def _handle_queued_pdus(self, room_queue):
+        """Process PDUs which got queued up while we were busy send_joining.
+
+        Args:
+            room_queue (list[FrozenEvent, str]): list of PDUs to be processed
+                and the servers that sent them
+        """
+        for p, origin in room_queue:
+            try:
+                logger.info("Processing queued PDU %s which was received "
+                            "while we were joining %s", p.event_id, p.room_id)
+                yield self.on_receive_pdu(origin, p)
+            except Exception as e:
+                logger.warn(
+                    "Error handling queued PDU %s from %s: %s",
+                    p.event_id, origin, e)
+
+    @defer.inlineCallbacks
     @log_function
     def on_make_join_request(self, room_id, user_id):
         """ We've received a /make_join/ request, so we create a partial
@@ -762,8 +1066,7 @@ class FederationHandler(BaseHandler):
         })
 
         try:
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder,
             )
         except AuthError as e:
@@ -791,9 +1094,19 @@ class FederationHandler(BaseHandler):
         )
 
         event.internal_metadata.outlier = False
-        # Send this event on behalf of the origin server since they may not
-        # have an up to data view of the state of the room at this event so
-        # will not know which servers to send the event to.
+        # Send this event on behalf of the origin server.
+        #
+        # The reasons we have the destination server rather than the origin
+        # server send it are slightly mysterious: the origin server should have
+        # all the neccessary state once it gets the response to the send_join,
+        # so it could send the event itself if it wanted to. It may be that
+        # doing it this way reduces failure modes, or avoids certain attacks
+        # where a new server selectively tells a subset of the federation that
+        # it has joined.
+        #
+        # The fact is that, as of the current writing, Synapse doesn't send out
+        # the join event over federation after joining, and changing it now
+        # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
         context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -812,10 +1125,9 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id, extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id, extra_users=extra_users
+        )
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
@@ -823,9 +1135,7 @@ class FederationHandler(BaseHandler):
                 yield user_joined_room(self.distributor, user, event.room_id)
 
         state_ids = context.prev_state_ids.values()
-        auth_chain = yield self.store.get_auth_chain(set(
-            [event.event_id] + state_ids
-        ))
+        auth_chain = yield self.store.get_auth_chain(state_ids)
 
         state = yield self.store.get_events(context.prev_state_ids.values())
 
@@ -842,6 +1152,34 @@ class FederationHandler(BaseHandler):
         """
         event = pdu
 
+        if event.state_key is None:
+            raise SynapseError(400, "The invite event did not have a state key")
+
+        is_blocked = yield self.store.is_room_blocked(event.room_id)
+        if is_blocked:
+            raise SynapseError(403, "This room has been blocked on this server")
+
+        if self.hs.config.block_non_admin_invites:
+            raise SynapseError(403, "This server does not accept room invites")
+
+        if not self.spam_checker.user_may_invite(
+            event.sender, event.state_key, event.room_id,
+        ):
+            raise SynapseError(
+                403, "This user is not permitted to send invites to this server/user"
+            )
+
+        membership = event.content.get("membership")
+        if event.type != EventTypes.Member or membership != Membership.INVITE:
+            raise SynapseError(400, "The event was not an m.room.member invite event")
+
+        sender_domain = get_domain_from_id(event.sender)
+        if sender_domain != origin:
+            raise SynapseError(400, "The invite event was not from the server sending it")
+
+        if not self.is_mine_id(event.state_key):
+            raise SynapseError(400, "The invite event must be for this server")
+
         event.internal_metadata.outlier = True
         event.internal_metadata.invite_from_remote = True
 
@@ -861,48 +1199,38 @@ class FederationHandler(BaseHandler):
         )
 
         target_user = UserID.from_string(event.state_key)
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=[target_user],
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=[target_user],
+        )
 
         defer.returnValue(event)
 
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
-        try:
-            origin, event = yield self._make_and_verify_event(
-                target_hosts,
-                room_id,
-                user_id,
-                "leave"
-            )
-            signed_event = self._sign_event(event)
-        except SynapseError:
-            raise
-        except CodeMessageException as e:
-            logger.warn("Failed to reject invite: %s", e)
-            raise SynapseError(500, "Failed to reject invite")
-
-        # Try the host we successfully got a response to /make_join/
-        # request first.
+        origin, event = yield self._make_and_verify_event(
+            target_hosts,
+            room_id,
+            user_id,
+            "leave"
+        )
+        # Mark as outlier as we don't have any state for this event; we're not
+        # even in the room.
+        event.internal_metadata.outlier = True
+        event = self._sign_event(event)
+
+        # Try the host that we succesfully called /make_leave/ on first for
+        # the /send_leave/ request.
         try:
             target_hosts.remove(origin)
             target_hosts.insert(0, origin)
         except ValueError:
             pass
 
-        try:
-            yield self.replication_layer.send_leave(
-                target_hosts,
-                signed_event
-            )
-        except SynapseError:
-            raise
-        except CodeMessageException as e:
-            logger.warn("Failed to reject invite: %s", e)
-            raise SynapseError(500, "Failed to reject invite")
+        yield self.replication_layer.send_leave(
+            target_hosts,
+            event
+        )
 
         context = yield self.state_handler.compute_event_context(event)
 
@@ -978,8 +1306,7 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -1023,10 +1350,9 @@ class FederationHandler(BaseHandler):
             target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
-        with PreserveLoggingContext():
-            self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id, extra_users=extra_users
-            )
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id, extra_users=extra_users
+        )
 
         defer.returnValue(None)
 
@@ -1061,7 +1387,7 @@ class FederationHandler(BaseHandler):
             for event in res:
                 # We sign these again because there was a bug where we
                 # incorrectly signed things the first time round
-                if self.hs.is_mine_id(event.event_id):
+                if self.is_mine_id(event.event_id):
                     event.signatures.update(
                         compute_event_signature(
                             event,
@@ -1096,7 +1422,7 @@ class FederationHandler(BaseHandler):
                     if prev_id != event.event_id:
                         results[(event.type, event.state_key)] = prev_id
                 else:
-                    del results[(event.type, event.state_key)]
+                    results.pop((event.type, event.state_key), None)
 
             defer.returnValue(results.values())
         else:
@@ -1134,7 +1460,7 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            if self.hs.is_mine_id(event.event_id):
+            if self.is_mine_id(event.event_id):
                 # FIXME: This is a temporary work around where we occasionally
                 # return events slightly differently than when they were
                 # originally signed
@@ -1178,23 +1504,33 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )
 
-        if not event.internal_metadata.is_outlier():
-            action_generator = ActionGenerator(self.hs)
-            yield action_generator.handle_push_actions_for_event(
-                event, context
+        try:
+            if not event.internal_metadata.is_outlier() and not backfilled:
+                yield self.action_generator.handle_push_actions_for_event(
+                    event, context
+                )
+
+            event_stream_id, max_stream_id = yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
             )
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            tp, value, tb = sys.exc_info()
 
-        event_stream_id, max_stream_id = yield self.store.persist_event(
-            event,
-            context=context,
-            backfilled=backfilled,
-        )
+            logcontext.run_in_background(
+                self.store.remove_push_actions_from_staging,
+                event.event_id,
+            )
+
+            six.reraise(tp, value, tb)
 
         if not backfilled:
             # this intentionally does not yield: we don't care about the result
             # and don't need to wait for it.
-            preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
-                event_stream_id, max_stream_id
+            logcontext.run_in_background(
+                self.pusher_pool.on_new_notifications,
+                event_stream_id, max_stream_id,
             )
 
         defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1206,16 +1542,17 @@ class FederationHandler(BaseHandler):
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """
-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                logcontext.run_in_background(
+                    self._prep_event,
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
                     auth_events=ev_info.get("auth_events"),
                 )
                 for ev_info in event_infos
-            ]
+            ], consumeErrors=True,
         ))
 
         yield self.store.persist_events(
@@ -1325,7 +1662,17 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
+        """
+
+        Args:
+            origin:
+            event:
+            state:
+            auth_events:
 
+        Returns:
+            Deferred, which resolves to synapse.events.snapshot.EventContext
+        """
         context = yield self.state_handler.compute_event_context(
             event, old_state=state,
         )
@@ -1362,7 +1709,7 @@ class FederationHandler(BaseHandler):
 
             context.rejected = RejectedReason.AUTH_ERROR
 
-        if event.type == EventTypes.GuestAccess:
+        if event.type == EventTypes.GuestAccess and not context.rejected:
             yield self.maybe_kick_guest_users(event)
 
         defer.returnValue(context)
@@ -1379,7 +1726,11 @@ class FederationHandler(BaseHandler):
                 pass
 
         # Now get the current auth_chain for the event.
-        local_auth_chain = yield self.store.get_auth_chain([event_id])
+        event = yield self.store.get_event(event_id)
+        local_auth_chain = yield self.store.get_auth_chain(
+            [auth_id for auth_id, _ in event.auth_events],
+            include_given=True
+        )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
         # everyone.
@@ -1427,6 +1778,17 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
     def do_auth(self, origin, event, context, auth_events):
+        """
+
+        Args:
+            origin (str):
+            event (synapse.events.FrozenEvent):
+            context (synapse.events.snapshot.EventContext):
+            auth_events (dict[(str, str)->str]):
+
+        Returns:
+            defer.Deferred[None]
+        """
         # Check if we have all the auth events.
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
@@ -1437,7 +1799,8 @@ class FederationHandler(BaseHandler):
             event_key = None
 
         if event_auth_events - current_state:
-            have_events = yield self.store.have_events(
+            # TODO: can we use store.have_seen_events here instead?
+            have_events = yield self.store.get_seen_events_with_rejections(
                 event_auth_events - current_state
             )
         else:
@@ -1460,12 +1823,12 @@ class FederationHandler(BaseHandler):
                     origin, event.room_id, event.event_id
                 )
 
-                seen_remotes = yield self.store.have_events(
+                seen_remotes = yield self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
                 )
 
                 for e in remote_auth_chain:
-                    if e.event_id in seen_remotes.keys():
+                    if e.event_id in seen_remotes:
                         continue
 
                     if e.event_id == event.event_id:
@@ -1492,11 +1855,11 @@ class FederationHandler(BaseHandler):
                     except AuthError:
                         pass
 
-                have_events = yield self.store.have_events(
+                have_events = yield self.store.get_seen_events_with_rejections(
                     [e_id for e_id, _ in event.auth_events]
                 )
                 seen_events = set(have_events.keys())
-            except:
+            except Exception:
                 # FIXME:
                 logger.exception("Failed to get auth chain")
 
@@ -1509,18 +1872,18 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
-                [
-                    preserve_fn(self.store.get_event)(
+            different_events = yield logcontext.make_deferred_yieldable(
+                defer.gatherResults([
+                    logcontext.run_in_background(
+                        self.store.get_event,
                         d,
                         allow_none=True,
                         allow_rejected=False,
                     )
                     for d in different_auth
                     if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True
-            )).addErrback(unwrapFirstError)
+                ], consumeErrors=True)
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
@@ -1539,16 +1902,9 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
+                )
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1572,7 +1928,9 @@ class FederationHandler(BaseHandler):
                 auth_ids = yield self.auth.compute_auth_events(
                     event, context.prev_state_ids
                 )
-                local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+                local_auth_chain = yield self.store.get_auth_chain(
+                    auth_ids, include_given=True
+                )
 
                 try:
                     # 2. Get remote difference.
@@ -1583,13 +1941,13 @@ class FederationHandler(BaseHandler):
                         local_auth_chain,
                     )
 
-                    seen_remotes = yield self.store.have_events(
+                    seen_remotes = yield self.store.have_seen_events(
                         [e.event_id for e in result["auth_chain"]]
                     )
 
                     # 3. Process any remote auth chain events we haven't seen.
                     for ev in result["auth_chain"]:
-                        if ev.event_id in seen_remotes.keys():
+                        if ev.event_id in seen_remotes:
                             continue
 
                         if ev.event_id == event.event_id:
@@ -1619,23 +1977,16 @@ class FederationHandler(BaseHandler):
                         except AuthError:
                             pass
 
-                except:
+                except Exception:
                     # FIXME:
                     logger.exception("Failed to query auth chain")
 
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                yield self._update_context_for_auth_events(
+                    event, context, auth_events, event_key,
+                )
 
         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1644,6 +1995,45 @@ class FederationHandler(BaseHandler):
             raise e
 
     @defer.inlineCallbacks
+    def _update_context_for_auth_events(self, event, context, auth_events,
+                                        event_key):
+        """Update the state_ids in an event context after auth event resolution,
+        storing the changes as a new state group.
+
+        Args:
+            event (Event): The event we're handling the context for
+
+            context (synapse.events.snapshot.EventContext): event context
+                to be updated
+
+            auth_events (dict[(str, str)->str]): Events to update in the event
+                context.
+
+            event_key ((str, str)): (type, state_key) for the current event.
+                this will not be included in the current_state in the context.
+        """
+        state_updates = {
+            k: a.event_id for k, a in auth_events.iteritems()
+            if k != event_key
+        }
+        context.current_state_ids = dict(context.current_state_ids)
+        context.current_state_ids.update(state_updates)
+        if context.delta_ids is not None:
+            context.delta_ids = dict(context.delta_ids)
+            context.delta_ids.update(state_updates)
+        context.prev_state_ids = dict(context.prev_state_ids)
+        context.prev_state_ids.update({
+            k: a.event_id for k, a in auth_events.iteritems()
+        })
+        context.state_group = yield self.store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=context.prev_group,
+            delta_ids=context.delta_ids,
+            current_state_ids=context.current_state_ids,
+        )
+
+    @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
         """ Given a local and remote auth chain, find the differences. This
         assumes that we have already processed all events in remote_auth
@@ -1686,7 +2076,7 @@ class FederationHandler(BaseHandler):
         def get_next(it, opt=None):
             try:
                 return it.next()
-            except:
+            except Exception:
                 return opt
 
         current_local = get_next(local_iter)
@@ -1811,8 +2201,7 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
+            event, context = yield self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
 
@@ -1827,7 +2216,7 @@ class FederationHandler(BaseHandler):
                 raise e
 
             yield self._check_signature(event, context)
-            member_handler = self.hs.get_handlers().room_member_handler
+            member_handler = self.hs.get_room_member_handler()
             yield member_handler.send_membership_event(None, event, context)
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
@@ -1840,10 +2229,17 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     @log_function
     def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+        """Handle an exchange_third_party_invite request from a remote server
+
+        The remote server will call this when it wants to turn a 3pid invite
+        into a normal m.room.member invite.
+
+        Returns:
+            Deferred: resolves (to None)
+        """
         builder = self.event_builder_factory.new(event_dict)
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self.event_creation_handler.create_new_client_event(
             builder=builder,
         )
 
@@ -1858,10 +2254,13 @@ class FederationHandler(BaseHandler):
             raise e
         yield self._check_signature(event, context)
 
+        # XXX we send the invite here, but send_membership_event also sends it,
+        # so we end up making two requests. I think this is redundant.
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
         event.signatures.update(returned_invite.signatures)
-        member_handler = self.hs.get_handlers().room_member_handler
+
+        member_handler = self.hs.get_room_member_handler()
         yield member_handler.send_membership_event(None, event, context)
 
     @defer.inlineCallbacks
@@ -1890,8 +2289,9 @@ class FederationHandler(BaseHandler):
 
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(builder=builder)
+        event, context = yield self.event_creation_handler.create_new_client_event(
+            builder=builder,
+        )
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
new file mode 100644
index 0000000000..977993e7d4
--- /dev/null
+++ b/synapse/handlers/groups_local.py
@@ -0,0 +1,471 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.types import get_domain_from_id
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def _create_rerouter(func_name):
+    """Returns a function that looks at the group id and calls the function
+    on federation or the local group server if the group is local
+    """
+    def f(self, group_id, *args, **kwargs):
+        if self.is_mine_id(group_id):
+            return getattr(self.groups_server_handler, func_name)(
+                group_id, *args, **kwargs
+            )
+        else:
+            destination = get_domain_from_id(group_id)
+            return getattr(self.transport_client, func_name)(
+                destination, group_id, *args, **kwargs
+            )
+    return f
+
+
+class GroupsLocalHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.room_list_handler = hs.get_room_list_handler()
+        self.groups_server_handler = hs.get_groups_server_handler()
+        self.transport_client = hs.get_federation_transport_client()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.keyring = hs.get_keyring()
+        self.is_mine_id = hs.is_mine_id
+        self.signing_key = hs.config.signing_key[0]
+        self.server_name = hs.hostname
+        self.notifier = hs.get_notifier()
+        self.attestations = hs.get_groups_attestation_signing()
+
+        self.profile_handler = hs.get_profile_handler()
+
+        # Ensure attestations get renewed
+        hs.get_groups_attestation_renewer()
+
+    # The following functions merely route the query to the local groups server
+    # or federation depending on if the group is local or remote
+
+    get_group_profile = _create_rerouter("get_group_profile")
+    update_group_profile = _create_rerouter("update_group_profile")
+    get_rooms_in_group = _create_rerouter("get_rooms_in_group")
+
+    get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
+
+    add_room_to_group = _create_rerouter("add_room_to_group")
+    update_room_in_group = _create_rerouter("update_room_in_group")
+    remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+    update_group_summary_room = _create_rerouter("update_group_summary_room")
+    delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+    update_group_category = _create_rerouter("update_group_category")
+    delete_group_category = _create_rerouter("delete_group_category")
+    get_group_category = _create_rerouter("get_group_category")
+    get_group_categories = _create_rerouter("get_group_categories")
+
+    update_group_summary_user = _create_rerouter("update_group_summary_user")
+    delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+    update_group_role = _create_rerouter("update_group_role")
+    delete_group_role = _create_rerouter("delete_group_role")
+    get_group_role = _create_rerouter("get_group_role")
+    get_group_roles = _create_rerouter("get_group_roles")
+
+    set_group_join_policy = _create_rerouter("set_group_join_policy")
+
+    @defer.inlineCallbacks
+    def get_group_summary(self, group_id, requester_user_id):
+        """Get the group summary for a group.
+
+        If the group is remote we check that the users have valid attestations.
+        """
+        if self.is_mine_id(group_id):
+            res = yield self.groups_server_handler.get_group_summary(
+                group_id, requester_user_id
+            )
+        else:
+            res = yield self.transport_client.get_group_summary(
+                get_domain_from_id(group_id), group_id, requester_user_id,
+            )
+
+            group_server_name = get_domain_from_id(group_id)
+
+            # Loop through the users and validate the attestations.
+            chunk = res["users_section"]["users"]
+            valid_users = []
+            for entry in chunk:
+                g_user_id = entry["user_id"]
+                attestation = entry.pop("attestation", {})
+                try:
+                    if get_domain_from_id(g_user_id) != group_server_name:
+                        yield self.attestations.verify_attestation(
+                            attestation,
+                            group_id=group_id,
+                            user_id=g_user_id,
+                            server_name=get_domain_from_id(g_user_id),
+                        )
+                    valid_users.append(entry)
+                except Exception as e:
+                    logger.info("Failed to verify user is in group: %s", e)
+
+            res["users_section"]["users"] = valid_users
+
+            res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
+            res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
+
+        # Add `is_publicised` flag to indicate whether the user has publicised their
+        # membership of the group on their profile
+        result = yield self.store.get_publicised_groups_for_user(requester_user_id)
+        is_publicised = group_id in result
+
+        res.setdefault("user", {})["is_publicised"] = is_publicised
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def create_group(self, group_id, user_id, content):
+        """Create a group
+        """
+
+        logger.info("Asking to create group with ID: %r", group_id)
+
+        if self.is_mine_id(group_id):
+            res = yield self.groups_server_handler.create_group(
+                group_id, user_id, content
+            )
+            local_attestation = None
+            remote_attestation = None
+        else:
+            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            content["attestation"] = local_attestation
+
+            content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+
+            res = yield self.transport_client.create_group(
+                get_domain_from_id(group_id), group_id, user_id, content,
+            )
+
+            remote_attestation = res["attestation"]
+            yield self.attestations.verify_attestation(
+                remote_attestation,
+                group_id=group_id,
+                user_id=user_id,
+                server_name=get_domain_from_id(group_id),
+            )
+
+        is_publicised = content.get("publicise", False)
+        token = yield self.store.register_user_group_membership(
+            group_id, user_id,
+            membership="join",
+            is_admin=True,
+            local_attestation=local_attestation,
+            remote_attestation=remote_attestation,
+            is_publicised=is_publicised,
+        )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def get_users_in_group(self, group_id, requester_user_id):
+        """Get users in a group
+        """
+        if self.is_mine_id(group_id):
+            res = yield self.groups_server_handler.get_users_in_group(
+                group_id, requester_user_id
+            )
+            defer.returnValue(res)
+
+        group_server_name = get_domain_from_id(group_id)
+
+        res = yield self.transport_client.get_users_in_group(
+            get_domain_from_id(group_id), group_id, requester_user_id,
+        )
+
+        chunk = res["chunk"]
+        valid_entries = []
+        for entry in chunk:
+            g_user_id = entry["user_id"]
+            attestation = entry.pop("attestation", {})
+            try:
+                if get_domain_from_id(g_user_id) != group_server_name:
+                    yield self.attestations.verify_attestation(
+                        attestation,
+                        group_id=group_id,
+                        user_id=g_user_id,
+                        server_name=get_domain_from_id(g_user_id),
+                    )
+                valid_entries.append(entry)
+            except Exception as e:
+                logger.info("Failed to verify user is in group: %s", e)
+
+        res["chunk"] = valid_entries
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def join_group(self, group_id, user_id, content):
+        """Request to join a group
+        """
+        if self.is_mine_id(group_id):
+            yield self.groups_server_handler.join_group(
+                group_id, user_id, content
+            )
+            local_attestation = None
+            remote_attestation = None
+        else:
+            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            content["attestation"] = local_attestation
+
+            res = yield self.transport_client.join_group(
+                get_domain_from_id(group_id), group_id, user_id, content,
+            )
+
+            remote_attestation = res["attestation"]
+
+            yield self.attestations.verify_attestation(
+                remote_attestation,
+                group_id=group_id,
+                user_id=user_id,
+                server_name=get_domain_from_id(group_id),
+            )
+
+        # TODO: Check that the group is public and we're being added publically
+        is_publicised = content.get("publicise", False)
+
+        token = yield self.store.register_user_group_membership(
+            group_id, user_id,
+            membership="join",
+            is_admin=False,
+            local_attestation=local_attestation,
+            remote_attestation=remote_attestation,
+            is_publicised=is_publicised,
+        )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def accept_invite(self, group_id, user_id, content):
+        """Accept an invite to a group
+        """
+        if self.is_mine_id(group_id):
+            yield self.groups_server_handler.accept_invite(
+                group_id, user_id, content
+            )
+            local_attestation = None
+            remote_attestation = None
+        else:
+            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            content["attestation"] = local_attestation
+
+            res = yield self.transport_client.accept_group_invite(
+                get_domain_from_id(group_id), group_id, user_id, content,
+            )
+
+            remote_attestation = res["attestation"]
+
+            yield self.attestations.verify_attestation(
+                remote_attestation,
+                group_id=group_id,
+                user_id=user_id,
+                server_name=get_domain_from_id(group_id),
+            )
+
+        # TODO: Check that the group is public and we're being added publically
+        is_publicised = content.get("publicise", False)
+
+        token = yield self.store.register_user_group_membership(
+            group_id, user_id,
+            membership="join",
+            is_admin=False,
+            local_attestation=local_attestation,
+            remote_attestation=remote_attestation,
+            is_publicised=is_publicised,
+        )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
+
+        defer.returnValue({})
+
+    @defer.inlineCallbacks
+    def invite(self, group_id, user_id, requester_user_id, config):
+        """Invite a user to a group
+        """
+        content = {
+            "requester_user_id": requester_user_id,
+            "config": config,
+        }
+        if self.is_mine_id(group_id):
+            res = yield self.groups_server_handler.invite_to_group(
+                group_id, user_id, requester_user_id, content,
+            )
+        else:
+            res = yield self.transport_client.invite_to_group(
+                get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+                content,
+            )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def on_invite(self, group_id, user_id, content):
+        """One of our users were invited to a group
+        """
+        # TODO: Support auto join and rejection
+
+        if not self.is_mine_id(user_id):
+            raise SynapseError(400, "User not on this server")
+
+        local_profile = {}
+        if "profile" in content:
+            if "name" in content["profile"]:
+                local_profile["name"] = content["profile"]["name"]
+            if "avatar_url" in content["profile"]:
+                local_profile["avatar_url"] = content["profile"]["avatar_url"]
+
+        token = yield self.store.register_user_group_membership(
+            group_id, user_id,
+            membership="invite",
+            content={"profile": local_profile, "inviter": content["inviter"]},
+        )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
+        try:
+            user_profile = yield self.profile_handler.get_profile(user_id)
+        except Exception as e:
+            logger.warn("No profile for user %s: %s", user_id, e)
+            user_profile = {}
+
+        defer.returnValue({"state": "invite", "user_profile": user_profile})
+
+    @defer.inlineCallbacks
+    def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+        """Remove a user from a group
+        """
+        if user_id == requester_user_id:
+            token = yield self.store.register_user_group_membership(
+                group_id, user_id,
+                membership="leave",
+            )
+            self.notifier.on_new_event(
+                "groups_key", token, users=[user_id],
+            )
+
+            # TODO: Should probably remember that we tried to leave so that we can
+            # retry if the group server is currently down.
+
+        if self.is_mine_id(group_id):
+            res = yield self.groups_server_handler.remove_user_from_group(
+                group_id, user_id, requester_user_id, content,
+            )
+        else:
+            content["requester_user_id"] = requester_user_id
+            res = yield self.transport_client.remove_user_from_group(
+                get_domain_from_id(group_id), group_id, requester_user_id,
+                user_id, content,
+            )
+
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def user_removed_from_group(self, group_id, user_id, content):
+        """One of our users was removed/kicked from a group
+        """
+        # TODO: Check if user in group
+        token = yield self.store.register_user_group_membership(
+            group_id, user_id,
+            membership="leave",
+        )
+        self.notifier.on_new_event(
+            "groups_key", token, users=[user_id],
+        )
+
+    @defer.inlineCallbacks
+    def get_joined_groups(self, user_id):
+        group_ids = yield self.store.get_joined_groups(user_id)
+        defer.returnValue({"groups": group_ids})
+
+    @defer.inlineCallbacks
+    def get_publicised_groups_for_user(self, user_id):
+        if self.hs.is_mine_id(user_id):
+            result = yield self.store.get_publicised_groups_for_user(user_id)
+
+            # Check AS associated groups for this user - this depends on the
+            # RegExps in the AS registration file (under `users`)
+            for app_service in self.store.get_app_services():
+                result.extend(app_service.get_groups_for_user(user_id))
+
+            defer.returnValue({"groups": result})
+        else:
+            bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+                get_domain_from_id(user_id), [user_id],
+            )
+            result = bulk_result.get("users", {}).get(user_id)
+            # TODO: Verify attestations
+            defer.returnValue({"groups": result})
+
+    @defer.inlineCallbacks
+    def bulk_get_publicised_groups(self, user_ids, proxy=True):
+        destinations = {}
+        local_users = set()
+
+        for user_id in user_ids:
+            if self.hs.is_mine_id(user_id):
+                local_users.add(user_id)
+            else:
+                destinations.setdefault(
+                    get_domain_from_id(user_id), set()
+                ).add(user_id)
+
+        if not proxy and destinations:
+            raise SynapseError(400, "Some user_ids are not local")
+
+        results = {}
+        failed_results = []
+        for destination, dest_user_ids in destinations.iteritems():
+            try:
+                r = yield self.transport_client.bulk_get_publicised_groups(
+                    destination, list(dest_user_ids),
+                )
+                results.update(r["users"])
+            except Exception:
+                failed_results.extend(dest_user_ids)
+
+        for uid in local_users:
+            results[uid] = yield self.store.get_publicised_groups_for_user(
+                uid
+            )
+
+            # Check AS associated groups for this user - this depends on the
+            # RegExps in the AS registration file (under `users`)
+            for app_service in self.store.get_app_services():
+                results[uid].extend(app_service.get_groups_for_user(uid))
+
+        defer.returnValue({"users": results})
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 559e5d5a71..91a0898860 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,18 +15,20 @@
 # limitations under the License.
 
 """Utilities for interacting with Identity Servers"""
+
+import logging
+
+import simplejson as json
+
 from twisted.internet import defer
 
 from synapse.api.errors import (
-    CodeMessageException
+    MatrixCodeMessageException, CodeMessageException
 )
 from ._base import BaseHandler
 from synapse.util.async import run_on_reactor
 from synapse.api.errors import SynapseError, Codes
 
-import json
-import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -89,6 +92,9 @@ class IdentityHandler(BaseHandler):
                 ),
                 {'sid': creds['sid'], 'client_secret': client_secret}
             )
+        except MatrixCodeMessageException as e:
+            logger.info("getValidated3pid failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
         except CodeMessageException as e:
             data = json.loads(e.msg)
 
@@ -150,7 +156,7 @@ class IdentityHandler(BaseHandler):
         params.update(kwargs)
 
         try:
-            data = yield self.http_client.post_urlencoded_get_json(
+            data = yield self.http_client.post_json_get_json(
                 "https://%s%s" % (
                     id_server,
                     "/_matrix/identity/api/v1/validate/email/requestToken"
@@ -158,6 +164,46 @@ class IdentityHandler(BaseHandler):
                 params
             )
             defer.returnValue(data)
+        except MatrixCodeMessageException as e:
+            logger.info("Proxied requestToken failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
+        except CodeMessageException as e:
+            logger.info("Proxied requestToken failed: %r", e)
+            raise e
+
+    @defer.inlineCallbacks
+    def requestMsisdnToken(
+            self, id_server, country, phone_number,
+            client_secret, send_attempt, **kwargs
+    ):
+        yield run_on_reactor()
+
+        if not self._should_trust_id_server(id_server):
+            raise SynapseError(
+                400, "Untrusted ID server '%s'" % id_server,
+                Codes.SERVER_NOT_TRUSTED
+            )
+
+        params = {
+            'country': country,
+            'phone_number': phone_number,
+            'client_secret': client_secret,
+            'send_attempt': send_attempt,
+        }
+        params.update(kwargs)
+
+        try:
+            data = yield self.http_client.post_json_get_json(
+                "https://%s%s" % (
+                    id_server,
+                    "/_matrix/identity/api/v1/validate/msisdn/requestToken"
+                ),
+                params
+            )
+            defer.returnValue(data)
+        except MatrixCodeMessageException as e:
+            logger.info("Proxied requestToken failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
         except CodeMessageException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e0ade4c164..cd33a86599 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -19,6 +19,7 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
+from synapse.handlers.presence import format_user_presence_state
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
     UserID, StreamToken,
@@ -26,7 +27,7 @@ from synapse.types import (
 from synapse.util import unwrapFirstError
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
@@ -162,10 +163,11 @@ class InitialSyncHandler(BaseHandler):
                         lambda states: states[event.event_id]
                     )
 
-                (messages, token), current_state = yield preserve_context_over_deferred(
+                (messages, token), current_state = yield make_deferred_yieldable(
                     defer.gatherResults(
                         [
-                            preserve_fn(self.store.get_recent_events_for_room)(
+                            run_in_background(
+                                self.store.get_recent_events_for_room,
                                 event.room_id,
                                 limit=limit,
                                 end_token=room_end_token,
@@ -213,7 +215,7 @@ class InitialSyncHandler(BaseHandler):
                     })
 
                 d["account_data"] = account_data_events
-            except:
+            except Exception:
                 logger.exception("Failed to get snapshot")
 
         yield concurrently_execute(handle_room, room_list, 10)
@@ -225,9 +227,17 @@ class InitialSyncHandler(BaseHandler):
                 "content": content,
             })
 
+        now = self.clock.time_msec()
+
         ret = {
             "rooms": rooms_ret,
-            "presence": presence,
+            "presence": [
+                {
+                    "type": "m.presence",
+                    "content": format_user_presence_state(event, now),
+                }
+                for event in presence
+            ],
             "account_data": account_data_events,
             "receipts": receipt,
             "end": now_token.to_string(),
@@ -382,9 +392,10 @@ class InitialSyncHandler(BaseHandler):
 
         presence, receipts, (messages, token) = yield defer.gatherResults(
             [
-                preserve_fn(get_presence)(),
-                preserve_fn(get_receipts)(),
-                preserve_fn(self.store.get_recent_events_for_room)(
+                run_in_background(get_presence),
+                run_in_background(get_receipts),
+                run_in_background(
+                    self.store.get_recent_events_for_room,
                     room_id,
                     limit=limit,
                     end_token=now_token.room_key,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7a498af5a2..b793fc4df7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 - 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,31 +13,64 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+import simplejson
+import sys
 
-from twisted.internet import defer
+from canonicaljson import encode_canonical_json
+import six
+from twisted.internet import defer, reactor
+from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
+from synapse.api.constants import EventTypes, Membership, MAX_DEPTH
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
-from synapse.push.action_generator import ActionGenerator
 from synapse.types import (
     UserID, RoomAlias, RoomStreamToken,
 )
 from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import measure_func
+from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
+from synapse.replication.http.send_event import send_event_to_master
 
 from ._base import BaseHandler
 
-from canonicaljson import encode_canonical_json
+logger = logging.getLogger(__name__)
 
-import logging
-import random
 
-logger = logging.getLogger(__name__)
+class PurgeStatus(object):
+    """Object tracking the status of a purge request
+
+    This class contains information on the progress of a purge request, for
+    return by get_purge_status.
+
+    Attributes:
+        status (int): Tracks whether this request has completed. One of
+            STATUS_{ACTIVE,COMPLETE,FAILED}
+    """
+
+    STATUS_ACTIVE = 0
+    STATUS_COMPLETE = 1
+    STATUS_FAILED = 2
+
+    STATUS_TEXT = {
+        STATUS_ACTIVE: "active",
+        STATUS_COMPLETE: "complete",
+        STATUS_FAILED: "failed",
+    }
+
+    def __init__(self):
+        self.status = PurgeStatus.STATUS_ACTIVE
+
+    def asdict(self):
+        return {
+            "status": PurgeStatus.STATUS_TEXT[self.status]
+        }
 
 
 class MessageHandler(BaseHandler):
@@ -46,25 +80,89 @@ class MessageHandler(BaseHandler):
         self.hs = hs
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
-        self.validator = EventValidator()
 
         self.pagination_lock = ReadWriteLock()
+        self._purges_in_progress_by_room = set()
+        # map from purge id to PurgeStatus
+        self._purges_by_id = {}
 
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Limiter(max_count=5)
+    def start_purge_history(self, room_id, topological_ordering,
+                            delete_local_events=False):
+        """Start off a history purge on a room.
+
+        Args:
+            room_id (str): The room to purge from
+
+            topological_ordering (int): minimum topo ordering to preserve
+            delete_local_events (bool): True to delete local events as well as
+                remote ones
+
+        Returns:
+            str: unique ID for this purge transaction.
+        """
+        if room_id in self._purges_in_progress_by_room:
+            raise SynapseError(
+                400,
+                "History purge already in progress for %s" % (room_id, ),
+            )
+
+        purge_id = random_string(16)
+
+        # we log the purge_id here so that it can be tied back to the
+        # request id in the log lines.
+        logger.info("[purge] starting purge_id %s", purge_id)
+
+        self._purges_by_id[purge_id] = PurgeStatus()
+        run_in_background(
+            self._purge_history,
+            purge_id, room_id, topological_ordering, delete_local_events,
+        )
+        return purge_id
 
     @defer.inlineCallbacks
-    def purge_history(self, room_id, event_id):
-        event = yield self.store.get_event(event_id)
+    def _purge_history(self, purge_id, room_id, topological_ordering,
+                       delete_local_events):
+        """Carry out a history purge on a room.
 
-        if event.room_id != room_id:
-            raise SynapseError(400, "Event is for wrong room.")
+        Args:
+            purge_id (str): The id for this purge
+            room_id (str): The room to purge from
+            topological_ordering (int): minimum topo ordering to preserve
+            delete_local_events (bool): True to delete local events as well as
+                remote ones
 
-        depth = event.depth
+        Returns:
+            Deferred
+        """
+        self._purges_in_progress_by_room.add(room_id)
+        try:
+            with (yield self.pagination_lock.write(room_id)):
+                yield self.store.purge_history(
+                    room_id, topological_ordering, delete_local_events,
+                )
+            logger.info("[purge] complete")
+            self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
+        except Exception:
+            logger.error("[purge] failed: %s", Failure().getTraceback().rstrip())
+            self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
+        finally:
+            self._purges_in_progress_by_room.discard(room_id)
+
+            # remove the purge from the list 24 hours after it completes
+            def clear_purge():
+                del self._purges_by_id[purge_id]
+            reactor.callLater(24 * 3600, clear_purge)
+
+    def get_purge_status(self, purge_id):
+        """Get the current status of an active purge
 
-        with (yield self.pagination_lock.write(room_id)):
-            yield self.store.delete_old_state(room_id, depth)
+        Args:
+            purge_id (str): purge_id returned by start_purge_history
+
+        Returns:
+            PurgeStatus|None
+        """
+        return self._purges_by_id.get(purge_id)
 
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
@@ -175,7 +273,167 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
 
     @defer.inlineCallbacks
-    def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
+    def get_room_data(self, user_id=None, room_id=None,
+                      event_type=None, state_key="", is_guest=False):
+        """ Get data from a room.
+
+        Args:
+            event : The room path event
+        Returns:
+            The path data content.
+        Raises:
+            SynapseError if something went wrong.
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            data = yield self.state_handler.get_current_state(
+                room_id, event_type, state_key
+            )
+        elif membership == Membership.LEAVE:
+            key = (event_type, state_key)
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], [key]
+            )
+            data = room_state[membership_event_id].get(key)
+
+        defer.returnValue(data)
+
+    @defer.inlineCallbacks
+    def _check_in_room_or_world_readable(self, room_id, user_id):
+        try:
+            # check_user_was_in_room will return the most recent membership
+            # event for the user if:
+            #  * The user is a non-guest user, and was ever in the room
+            #  * The user is a guest user, and has joined the room
+            # else it will throw.
+            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+            defer.returnValue((member_event.membership, member_event.event_id))
+            return
+        except AuthError:
+            visibility = yield self.state_handler.get_current_state(
+                room_id, EventTypes.RoomHistoryVisibility, ""
+            )
+            if (
+                visibility and
+                visibility.content["history_visibility"] == "world_readable"
+            ):
+                defer.returnValue((Membership.JOIN, None))
+                return
+            raise AuthError(
+                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            )
+
+    @defer.inlineCallbacks
+    def get_state_events(self, user_id, room_id, is_guest=False):
+        """Retrieve all state events for a given room. If the user is
+        joined to the room then return the current state. If the user has
+        left the room return the state events from when they left.
+
+        Args:
+            user_id(str): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A list of dicts representing state events. [{}, {}, {}]
+        """
+        membership, membership_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if membership == Membership.JOIN:
+            room_state = yield self.state_handler.get_current_state(room_id)
+        elif membership == Membership.LEAVE:
+            room_state = yield self.store.get_state_for_events(
+                [membership_event_id], None
+            )
+            room_state = room_state[membership_event_id]
+
+        now = self.clock.time_msec()
+        defer.returnValue(
+            [serialize_event(c, now) for c in room_state.values()]
+        )
+
+    @defer.inlineCallbacks
+    def get_joined_members(self, requester, room_id):
+        """Get all the joined members in the room and their profile information.
+
+        If the user has left the room return the state events from when they left.
+
+        Args:
+            requester(Requester): The user requesting state events.
+            room_id(str): The room ID to get all state events from.
+        Returns:
+            A dict of user_id to profile info
+        """
+        user_id = requester.user.to_string()
+        if not requester.app_service:
+            # We check AS auth after fetching the room membership, as it
+            # requires us to pull out all joined members anyway.
+            membership, _ = yield self._check_in_room_or_world_readable(
+                room_id, user_id
+            )
+            if membership != Membership.JOIN:
+                raise NotImplementedError(
+                    "Getting joined members after leaving is not implemented"
+                )
+
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+        # If this is an AS, double check that they are allowed to see the members.
+        # This can either be because the AS user is in the room or becuase there
+        # is a user in the room that the AS is "interested in"
+        if requester.app_service and user_id not in users_with_profile:
+            for uid in users_with_profile:
+                if requester.app_service.is_interested_in_user(uid):
+                    break
+            else:
+                # Loop fell through, AS has no interested users in room
+                raise AuthError(403, "Appservice not in room")
+
+        defer.returnValue({
+            user_id: {
+                "avatar_url": profile.avatar_url,
+                "display_name": profile.display_name,
+            }
+            for user_id, profile in users_with_profile.iteritems()
+        })
+
+
+class EventCreationHandler(object):
+    def __init__(self, hs):
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.clock = hs.get_clock()
+        self.validator = EventValidator()
+        self.profile_handler = hs.get_profile_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+        self.server_name = hs.hostname
+        self.ratelimiter = hs.get_ratelimiter()
+        self.notifier = hs.get_notifier()
+        self.config = hs.config
+
+        self.http_client = hs.get_simple_http_client()
+
+        # This is only used to get at ratelimit function, and maybe_kick_guest_users
+        self.base_handler = BaseHandler(hs)
+
+        self.pusher_pool = hs.get_pusherpool()
+
+        # We arbitrarily limit concurrent event creation for a room to 5.
+        # This is to stop us from diverging history *too* much.
+        self.limiter = Limiter(max_count=5)
+
+        self.action_generator = hs.get_action_generator()
+
+        self.spam_checker = hs.get_spam_checker()
+
+    @defer.inlineCallbacks
+    def create_event(self, requester, event_dict, token_id=None, txn_id=None,
+                     prev_events_and_hashes=None):
         """
         Given a dict from a client, create a new event.
 
@@ -185,49 +443,56 @@ class MessageHandler(BaseHandler):
         Adds display names to Join membership events.
 
         Args:
+            requester
             event_dict (dict): An entire event
             token_id (str)
             txn_id (str)
-            prev_event_ids (list): The prev event ids to use when creating the event
+
+            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+                the forward extremities to use as the prev_events for the
+                new event. For each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+
+                If None, they will be requested from the database.
 
         Returns:
             Tuple of created event (FrozenEvent), Context
         """
         builder = self.event_builder_factory.new(event_dict)
 
-        with (yield self.limiter.queue(builder.room_id)):
-            self.validator.validate_new(builder)
-
-            if builder.type == EventTypes.Member:
-                membership = builder.content.get("membership", None)
-                target = UserID.from_string(builder.state_key)
-
-                if membership in {Membership.JOIN, Membership.INVITE}:
-                    # If event doesn't include a display name, add one.
-                    profile = self.hs.get_handlers().profile_handler
-                    content = builder.content
-
-                    try:
-                        if "displayname" not in content:
-                            content["displayname"] = yield profile.get_displayname(target)
-                        if "avatar_url" not in content:
-                            content["avatar_url"] = yield profile.get_avatar_url(target)
-                    except Exception as e:
-                        logger.info(
-                            "Failed to get profile information for %r: %s",
-                            target, e
-                        )
+        self.validator.validate_new(builder)
+
+        if builder.type == EventTypes.Member:
+            membership = builder.content.get("membership", None)
+            target = UserID.from_string(builder.state_key)
+
+            if membership in {Membership.JOIN, Membership.INVITE}:
+                # If event doesn't include a display name, add one.
+                profile = self.profile_handler
+                content = builder.content
+
+                try:
+                    if "displayname" not in content:
+                        content["displayname"] = yield profile.get_displayname(target)
+                    if "avatar_url" not in content:
+                        content["avatar_url"] = yield profile.get_avatar_url(target)
+                except Exception as e:
+                    logger.info(
+                        "Failed to get profile information for %r: %s",
+                        target, e
+                    )
 
-            if token_id is not None:
-                builder.internal_metadata.token_id = token_id
+        if token_id is not None:
+            builder.internal_metadata.token_id = token_id
 
-            if txn_id is not None:
-                builder.internal_metadata.txn_id = txn_id
+        if txn_id is not None:
+            builder.internal_metadata.txn_id = txn_id
 
-            event, context = yield self._create_new_client_event(
-                builder=builder,
-                prev_event_ids=prev_event_ids,
-            )
+        event, context = yield self.create_new_client_event(
+            builder=builder,
+            requester=requester,
+            prev_events_and_hashes=prev_events_and_hashes,
+        )
 
         defer.returnValue((event, context))
 
@@ -248,21 +513,6 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )
 
-        # We check here if we are currently being rate limited, so that we
-        # don't do unnecessary work. We check again just before we actually
-        # send the event.
-        time_now = self.clock.time()
-        allowed, time_allowed = self.ratelimiter.send_message(
-            event.sender, time_now,
-            msg_rate_hz=self.hs.config.rc_messages_per_second,
-            burst_count=self.hs.config.rc_message_burst_count,
-            update=False,
-        )
-        if not allowed:
-            raise LimitExceededError(
-                retry_after_ms=int(1000 * (time_allowed - time_now)),
-            )
-
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -279,12 +529,6 @@ class MessageHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        if event.type == EventTypes.Message:
-            presence = self.hs.get_presence_handler()
-            # We don't want to block sending messages on any presence code. This
-            # matters as sometimes presence code can take a while.
-            preserve_fn(presence.bump_presence_active_time)(user)
-
     @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
@@ -318,144 +562,87 @@ class MessageHandler(BaseHandler):
 
         See self.create_event and self.send_nonmember_event.
         """
-        event, context = yield self.create_event(
-            event_dict,
-            token_id=requester.access_token_id,
-            txn_id=txn_id
-        )
-        yield self.send_nonmember_event(
-            requester,
-            event,
-            context,
-            ratelimit=ratelimit,
-        )
-        defer.returnValue(event)
-
-    @defer.inlineCallbacks
-    def get_room_data(self, user_id=None, room_id=None,
-                      event_type=None, state_key="", is_guest=False):
-        """ Get data from a room.
-
-        Args:
-            event : The room path event
-        Returns:
-            The path data content.
-        Raises:
-            SynapseError if something went wrong.
-        """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
 
-        if membership == Membership.JOIN:
-            data = yield self.state_handler.get_current_state(
-                room_id, event_type, state_key
+        # We limit the number of concurrent event sends in a room so that we
+        # don't fork the DAG too much. If we don't limit then we can end up in
+        # a situation where event persistence can't keep up, causing
+        # extremities to pile up, which in turn leads to state resolution
+        # taking longer.
+        with (yield self.limiter.queue(event_dict["room_id"])):
+            event, context = yield self.create_event(
+                requester,
+                event_dict,
+                token_id=requester.access_token_id,
+                txn_id=txn_id
             )
-        elif membership == Membership.LEAVE:
-            key = (event_type, state_key)
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], [key]
-            )
-            data = room_state[membership_event_id].get(key)
 
-        defer.returnValue(data)
+            spam_error = self.spam_checker.check_event_for_spam(event)
+            if spam_error:
+                if not isinstance(spam_error, basestring):
+                    spam_error = "Spam is not permitted here"
+                raise SynapseError(
+                    403, spam_error, Codes.FORBIDDEN
+                )
 
-    @defer.inlineCallbacks
-    def _check_in_room_or_world_readable(self, room_id, user_id):
-        try:
-            # check_user_was_in_room will return the most recent membership
-            # event for the user if:
-            #  * The user is a non-guest user, and was ever in the room
-            #  * The user is a guest user, and has joined the room
-            # else it will throw.
-            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
-            defer.returnValue((member_event.membership, member_event.event_id))
-            return
-        except AuthError:
-            visibility = yield self.state_handler.get_current_state(
-                room_id, EventTypes.RoomHistoryVisibility, ""
-            )
-            if (
-                visibility and
-                visibility.content["history_visibility"] == "world_readable"
-            ):
-                defer.returnValue((Membership.JOIN, None))
-                return
-            raise AuthError(
-                403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+            yield self.send_nonmember_event(
+                requester,
+                event,
+                context,
+                ratelimit=ratelimit,
             )
+        defer.returnValue(event)
 
+    @measure_func("create_new_client_event")
     @defer.inlineCallbacks
-    def get_state_events(self, user_id, room_id, is_guest=False):
-        """Retrieve all state events for a given room. If the user is
-        joined to the room then return the current state. If the user has
-        left the room return the state events from when they left.
+    def create_new_client_event(self, builder, requester=None,
+                                prev_events_and_hashes=None):
+        """Create a new event for a local client
 
         Args:
-            user_id(str): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
+            builder (EventBuilder):
+
+            requester (synapse.types.Requester|None):
+
+            prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+                the forward extremities to use as the prev_events for the
+                new event. For each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+
+                If None, they will be requested from the database.
+
         Returns:
-            A list of dicts representing state events. [{}, {}, {}]
+            Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
         """
-        membership, membership_event_id = yield self._check_in_room_or_world_readable(
-            room_id, user_id
-        )
 
-        if membership == Membership.JOIN:
-            room_state = yield self.state_handler.get_current_state(room_id)
-        elif membership == Membership.LEAVE:
-            room_state = yield self.store.get_state_for_events(
-                [membership_event_id], None
+        if prev_events_and_hashes is not None:
+            assert len(prev_events_and_hashes) <= 10, \
+                "Attempting to create an event with %i prev_events" % (
+                    len(prev_events_and_hashes),
             )
-            room_state = room_state[membership_event_id]
-
-        now = self.clock.time_msec()
-        defer.returnValue(
-            [serialize_event(c, now) for c in room_state.values()]
-        )
-
-    @measure_func("_create_new_client_event")
-    @defer.inlineCallbacks
-    def _create_new_client_event(self, builder, prev_event_ids=None):
-        if prev_event_ids:
-            prev_events = yield self.store.add_event_hashes(prev_event_ids)
-            prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
-            depth = prev_max_depth + 1
         else:
-            latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
-                builder.room_id,
-            )
-
-            # We want to limit the max number of prev events we point to in our
-            # new event
-            if len(latest_ret) > 10:
-                # Sort by reverse depth, so we point to the most recent.
-                latest_ret.sort(key=lambda a: -a[2])
-                new_latest_ret = latest_ret[:5]
-
-                # We also randomly point to some of the older events, to make
-                # sure that we don't completely ignore the older events.
-                if latest_ret[5:]:
-                    sample_size = min(5, len(latest_ret[5:]))
-                    new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
-                latest_ret = new_latest_ret
-
-            if latest_ret:
-                depth = max([d for _, _, d in latest_ret]) + 1
-            else:
-                depth = 1
+            prev_events_and_hashes = \
+                yield self.store.get_prev_events_for_room(builder.room_id)
+
+        if prev_events_and_hashes:
+            depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+            # we cap depth of generated events, to ensure that they are not
+            # rejected by other servers (and so that they can be persisted in
+            # the db)
+            depth = min(depth, MAX_DEPTH)
+        else:
+            depth = 1
 
-            prev_events = [
-                (event_id, prev_hashes)
-                for event_id, prev_hashes, _ in latest_ret
-            ]
+        prev_events = [
+            (event_id, prev_hashes)
+            for event_id, prev_hashes, _ in prev_events_and_hashes
+        ]
 
         builder.prev_events = prev_events
         builder.depth = depth
 
-        state_handler = self.state_handler
-
-        context = yield state_handler.compute_event_context(builder)
+        context = yield self.state.compute_event_context(builder)
+        if requester:
+            context.app_service = requester.app_service
 
         if builder.is_state():
             builder.prev_state = yield self.store.add_event_hashes(
@@ -488,12 +675,21 @@ class MessageHandler(BaseHandler):
         event,
         context,
         ratelimit=True,
-        extra_users=[]
+        extra_users=[],
     ):
-        # We now need to go and hit out to wherever we need to hit out to.
+        """Processes a new event. This includes checking auth, persisting it,
+        notifying users, sending to remote servers, etc.
 
-        if ratelimit:
-            self.ratelimit(requester)
+        If called from a worker will hit out to the master process for final
+        processing.
+
+        Args:
+            requester (Requester)
+            event (FrozenEvent)
+            context (EventContext)
+            ratelimit (bool)
+            extra_users (list(UserID)): Any extra users to notify about event
+        """
 
         try:
             yield self.auth.check_from_context(event, context)
@@ -501,7 +697,72 @@ class MessageHandler(BaseHandler):
             logger.warn("Denying new event %r because %s", event, err)
             raise err
 
-        yield self.maybe_kick_guest_users(event, context)
+        # Ensure that we can round trip before trying to persist in db
+        try:
+            dump = frozendict_json_encoder.encode(event.content)
+            simplejson.loads(dump)
+        except Exception:
+            logger.exception("Failed to encode content: %r", event.content)
+            raise
+
+        yield self.action_generator.handle_push_actions_for_event(
+            event, context
+        )
+
+        try:
+            # If we're a worker we need to hit out to the master.
+            if self.config.worker_app:
+                yield send_event_to_master(
+                    self.http_client,
+                    host=self.config.worker_replication_host,
+                    port=self.config.worker_replication_http_port,
+                    requester=requester,
+                    event=event,
+                    context=context,
+                    ratelimit=ratelimit,
+                    extra_users=extra_users,
+                )
+                return
+
+            yield self.persist_and_notify_client_event(
+                requester,
+                event,
+                context,
+                ratelimit=ratelimit,
+                extra_users=extra_users,
+            )
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            # Ensure that we actually remove the entries in the push actions
+            # staging area, if we calculated them.
+            tp, value, tb = sys.exc_info()
+
+            run_in_background(
+                self.store.remove_push_actions_from_staging,
+                event.event_id,
+            )
+
+            six.reraise(tp, value, tb)
+
+    @defer.inlineCallbacks
+    def persist_and_notify_client_event(
+        self,
+        requester,
+        event,
+        context,
+        ratelimit=True,
+        extra_users=[],
+    ):
+        """Called when we have fully built the event, have already
+        calculated the push actions for the event, and checked auth.
+
+        This should only be run on master.
+        """
+        assert not self.config.worker_app
+
+        if ratelimit:
+            yield self.base_handler.ratelimit(requester)
+
+        yield self.base_handler.maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
             # Check the alias is acually valid (at this time at least)
@@ -531,9 +792,9 @@ class MessageHandler(BaseHandler):
 
                 state_to_include_ids = [
                     e_id
-                    for k, e_id in context.current_state_ids.items()
+                    for k, e_id in context.current_state_ids.iteritems()
                     if k[0] in self.hs.config.room_invite_state_types
-                    or k[0] == EventTypes.Member and k[1] == event.sender
+                    or k == (EventTypes.Member, event.sender)
                 ]
 
                 state_to_include = yield self.store.get_events(state_to_include_ids)
@@ -545,7 +806,7 @@ class MessageHandler(BaseHandler):
                         "content": e.content,
                         "sender": e.sender,
                     }
-                    for e in state_to_include.values()
+                    for e in state_to_include.itervalues()
                 ]
 
                 invitee = UserID.from_string(event.state_key)
@@ -594,30 +855,39 @@ class MessageHandler(BaseHandler):
                 "Changing the room create event is forbidden",
             )
 
-        action_generator = ActionGenerator(self.hs)
-        yield action_generator.handle_push_actions_for_event(
-            event, context
-        )
-
         (event_stream_id, max_stream_id) = yield self.store.persist_event(
             event, context=context
         )
 
         # this intentionally does not yield: we don't care about the result
         # and don't need to wait for it.
-        preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+        run_in_background(
+            self.pusher_pool.on_new_notifications,
             event_stream_id, max_stream_id
         )
 
         @defer.inlineCallbacks
         def _notify():
             yield run_on_reactor()
-            yield self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=extra_users
-            )
+            try:
+                self.notifier.on_new_room_event(
+                    event, event_stream_id, max_stream_id,
+                    extra_users=extra_users
+                )
+            except Exception:
+                logger.exception("Error notifying about new room event")
 
-        preserve_fn(_notify)()
+        run_in_background(_notify)
 
-        # If invite, remove room_state from unsigned before sending.
-        event.unsigned.pop("invite_room_state", None)
+        if event.type == EventTypes.Message:
+            # We don't want to block sending messages on any presence code. This
+            # matters as sometimes presence code can take a while.
+            run_in_background(self._bump_active_time, requester.user)
+
+    @defer.inlineCallbacks
+    def _bump_active_time(self, user):
+        try:
+            presence = self.hs.get_presence_handler()
+            yield presence.bump_presence_active_time(user)
+        except Exception:
+            logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da610e430f..585f3e4da2 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -29,7 +29,9 @@ from synapse.api.errors import SynapseError
 from synapse.api.constants import PresenceState
 from synapse.storage.presence import UserPresenceState
 
-from synapse.util.logcontext import preserve_fn
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.async import Linearizer
+from synapse.util.logcontext import run_in_background
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -91,29 +93,30 @@ class PresenceHandler(object):
         self.store = hs.get_datastore()
         self.wheel_timer = WheelTimer()
         self.notifier = hs.get_notifier()
-        self.replication = hs.get_replication_layer()
         self.federation = hs.get_federation_sender()
 
         self.state = hs.get_state_handler()
 
-        self.replication.register_edu_handler(
+        federation_registry = hs.get_federation_registry()
+
+        federation_registry.register_edu_handler(
             "m.presence", self.incoming_presence
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_invite",
             lambda origin, content: self.invite_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
                 observer_user=UserID.from_string(content["observer_user"]),
             )
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_accept",
             lambda origin, content: self.accept_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
                 observer_user=UserID.from_string(content["observer_user"]),
             )
         )
-        self.replication.register_edu_handler(
+        federation_registry.register_edu_handler(
             "m.presence_deny",
             lambda origin, content: self.deny_presence(
                 observed_user=UserID.from_string(content["observed_user"]),
@@ -186,6 +189,7 @@ class PresenceHandler(object):
         # process_id to millisecond timestamp last updated.
         self.external_process_to_current_syncs = {}
         self.external_process_last_updated_ms = {}
+        self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
 
         # Start a LoopingCall in 30s that fires every 5s.
         # The initial delay is to allow disconnected clients a chance to
@@ -251,6 +255,14 @@ class PresenceHandler(object):
         logger.info("Finished _persist_unpersisted_changes")
 
     @defer.inlineCallbacks
+    def _update_states_and_catch_exception(self, new_states):
+        try:
+            res = yield self._update_states(new_states)
+            defer.returnValue(res)
+        except Exception:
+            logger.exception("Error updating presence")
+
+    @defer.inlineCallbacks
     def _update_states(self, new_states):
         """Updates presence of users. Sets the appropriate timeouts. Pokes
         the notifier and federation if and only if the changed presence state
@@ -315,11 +327,7 @@ class PresenceHandler(object):
             if to_federation_ping:
                 federation_presence_out_counter.inc_by(len(to_federation_ping))
 
-                _, _, hosts_to_states = yield self._get_interested_parties(
-                    to_federation_ping.values()
-                )
-
-                self._push_to_remotes(hosts_to_states)
+                self._push_to_remotes(to_federation_ping.values())
 
     def _handle_timeouts(self):
         """Checks the presence of users that have timed out and updates as
@@ -364,8 +372,8 @@ class PresenceHandler(object):
                     now=now,
                 )
 
-            preserve_fn(self._update_states)(changes)
-        except:
+            run_in_background(self._update_states_and_catch_exception, changes)
+        except Exception:
             logger.exception("Exception in _handle_timeouts loop")
 
     @defer.inlineCallbacks
@@ -422,20 +430,23 @@ class PresenceHandler(object):
 
         @defer.inlineCallbacks
         def _end():
-            if affect_presence:
+            try:
                 self.user_to_num_current_syncs[user_id] -= 1
 
                 prev_state = yield self.current_state_for_user(user_id)
                 yield self._update_states([prev_state.copy_and_replace(
                     last_user_sync_ts=self.clock.time_msec(),
                 )])
+            except Exception:
+                logger.exception("Error updating presence after sync")
 
         @contextmanager
         def _user_syncing():
             try:
                 yield
             finally:
-                preserve_fn(_end)()
+                if affect_presence:
+                    run_in_background(_end)
 
         defer.returnValue(_user_syncing())
 
@@ -508,6 +519,73 @@ class PresenceHandler(object):
         self.external_process_to_current_syncs[process_id] = syncing_user_ids
 
     @defer.inlineCallbacks
+    def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
+        """Update the syncing users for an external process as a delta.
+
+        Args:
+            process_id (str): An identifier for the process the users are
+                syncing against. This allows synapse to process updates
+                as user start and stop syncing against a given process.
+            user_id (str): The user who has started or stopped syncing
+            is_syncing (bool): Whether or not the user is now syncing
+            sync_time_msec(int): Time in ms when the user was last syncing
+        """
+        with (yield self.external_sync_linearizer.queue(process_id)):
+            prev_state = yield self.current_state_for_user(user_id)
+
+            process_presence = self.external_process_to_current_syncs.setdefault(
+                process_id, set()
+            )
+
+            updates = []
+            if is_syncing and user_id not in process_presence:
+                if prev_state.state == PresenceState.OFFLINE:
+                    updates.append(prev_state.copy_and_replace(
+                        state=PresenceState.ONLINE,
+                        last_active_ts=sync_time_msec,
+                        last_user_sync_ts=sync_time_msec,
+                    ))
+                else:
+                    updates.append(prev_state.copy_and_replace(
+                        last_user_sync_ts=sync_time_msec,
+                    ))
+                process_presence.add(user_id)
+            elif user_id in process_presence:
+                updates.append(prev_state.copy_and_replace(
+                    last_user_sync_ts=sync_time_msec,
+                ))
+
+            if not is_syncing:
+                process_presence.discard(user_id)
+
+            if updates:
+                yield self._update_states(updates)
+
+            self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
+
+    @defer.inlineCallbacks
+    def update_external_syncs_clear(self, process_id):
+        """Marks all users that had been marked as syncing by a given process
+        as offline.
+
+        Used when the process has stopped/disappeared.
+        """
+        with (yield self.external_sync_linearizer.queue(process_id)):
+            process_presence = self.external_process_to_current_syncs.pop(
+                process_id, set()
+            )
+            prev_states = yield self.current_state_for_users(process_presence)
+            time_now_ms = self.clock.time_msec()
+
+            yield self._update_states([
+                prev_state.copy_and_replace(
+                    last_user_sync_ts=time_now_ms,
+                )
+                for prev_state in prev_states.itervalues()
+            ])
+            self.external_process_last_updated_ms.pop(process_id, None)
+
+    @defer.inlineCallbacks
     def current_state_for_user(self, user_id):
         """Get the current presence state for a user.
         """
@@ -526,14 +604,14 @@ class PresenceHandler(object):
             for user_id in user_ids
         }
 
-        missing = [user_id for user_id, state in states.items() if not state]
+        missing = [user_id for user_id, state in states.iteritems() if not state]
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
             res = yield self.store.get_presence_for_users(missing)
             states.update(res)
 
-            missing = [user_id for user_id, state in states.items() if not state]
+            missing = [user_id for user_id, state in states.iteritems() if not state]
             if missing:
                 new = {
                     user_id: UserPresenceState.default(user_id)
@@ -545,89 +623,39 @@ class PresenceHandler(object):
         defer.returnValue(states)
 
     @defer.inlineCallbacks
-    def _get_interested_parties(self, states, calculate_remote_hosts=True):
-        """Given a list of states return which entities (rooms, users, servers)
-        are interested in the given states.
-
-        Returns:
-            3-tuple: `(room_ids_to_states, users_to_states, hosts_to_states)`,
-            with each item being a dict of `entity_name` -> `[UserPresenceState]`
-        """
-        room_ids_to_states = {}
-        users_to_states = {}
-        for state in states:
-            events = yield self.store.get_rooms_for_user(state.user_id)
-            for e in events:
-                room_ids_to_states.setdefault(e.room_id, []).append(state)
-
-            plist = yield self.store.get_presence_list_observers_accepted(state.user_id)
-            for u in plist:
-                users_to_states.setdefault(u, []).append(state)
-
-            # Always notify self
-            users_to_states.setdefault(state.user_id, []).append(state)
-
-        hosts_to_states = {}
-        if calculate_remote_hosts:
-            for room_id, states in room_ids_to_states.items():
-                local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-                if not local_states:
-                    continue
-
-                users = yield self.store.get_users_in_room(room_id)
-                hosts = set(get_domain_from_id(u) for u in users)
-
-                for host in hosts:
-                    hosts_to_states.setdefault(host, []).extend(local_states)
-
-        for user_id, states in users_to_states.items():
-            local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-            if not local_states:
-                continue
-
-            host = get_domain_from_id(user_id)
-            hosts_to_states.setdefault(host, []).extend(local_states)
-
-        # TODO: de-dup hosts_to_states, as a single host might have multiple
-        # of same presence
-
-        defer.returnValue((room_ids_to_states, users_to_states, hosts_to_states))
-
-    @defer.inlineCallbacks
     def _persist_and_notify(self, states):
         """Persist states in the database, poke the notifier and send to
         interested remote servers
         """
         stream_id, max_token = yield self.store.update_presence(states)
 
-        parties = yield self._get_interested_parties(states)
-        room_ids_to_states, users_to_states, hosts_to_states = parties
+        parties = yield get_interested_parties(self.store, states)
+        room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
             "presence_key", stream_id, rooms=room_ids_to_states.keys(),
-            users=[UserID.from_string(u) for u in users_to_states.keys()]
+            users=[UserID.from_string(u) for u in users_to_states]
         )
 
-        self._push_to_remotes(hosts_to_states)
+        self._push_to_remotes(states)
 
     @defer.inlineCallbacks
     def notify_for_states(self, state, stream_id):
-        parties = yield self._get_interested_parties([state])
-        room_ids_to_states, users_to_states, hosts_to_states = parties
+        parties = yield get_interested_parties(self.store, [state])
+        room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
             "presence_key", stream_id, rooms=room_ids_to_states.keys(),
-            users=[UserID.from_string(u) for u in users_to_states.keys()]
+            users=[UserID.from_string(u) for u in users_to_states]
         )
 
-    def _push_to_remotes(self, hosts_to_states):
+    def _push_to_remotes(self, states):
         """Sends state updates to remote servers.
 
         Args:
-            hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
+            states (list(UserPresenceState))
         """
-        for host, states in hosts_to_states.items():
-            self.federation.send_presence(host, states)
+        self.federation.send_presence(states)
 
     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -719,9 +747,7 @@ class PresenceHandler(object):
                 for state in updates
             ])
         else:
-            defer.returnValue([
-                format_user_presence_state(state, now) for state in updates
-            ])
+            defer.returnValue(updates)
 
     @defer.inlineCallbacks
     def set_state(self, target_user, state, ignore_status_msg=False):
@@ -766,18 +792,17 @@ class PresenceHandler(object):
         # don't need to send to local clients here, as that is done as part
         # of the event stream/sync.
         # TODO: Only send to servers not already in the room.
-        user_ids = yield self.store.get_users_in_room(room_id)
         if self.is_mine(user):
             state = yield self.current_state_for_user(user.to_string())
 
-            hosts = set(get_domain_from_id(u) for u in user_ids)
-            self._push_to_remotes({host: (state,) for host in hosts})
+            self._push_to_remotes([state])
         else:
+            user_ids = yield self.store.get_users_in_room(room_id)
             user_ids = filter(self.is_mine_id, user_ids)
 
             states = yield self.current_state_for_users(user_ids)
 
-            self._push_to_remotes({user.domain: states.values()})
+            self._push_to_remotes(states.values())
 
     @defer.inlineCallbacks
     def get_presence_list(self, observer_user, accepted=None):
@@ -795,6 +820,9 @@ class PresenceHandler(object):
             as_event=False,
         )
 
+        now = self.clock.time_msec()
+        results[:] = [format_user_presence_state(r, now) for r in results]
+
         is_accepted = {
             row["observed_user_id"]: row["accepted"] for row in presence_list
         }
@@ -847,6 +875,7 @@ class PresenceHandler(object):
             )
 
             state_dict = yield self.get_state(observed_user, as_event=False)
+            state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
 
             self.federation.send_edu(
                 destination=observer_user.domain,
@@ -910,11 +939,12 @@ class PresenceHandler(object):
     def is_visible(self, observed_user, observer_user):
         """Returns whether a user can see another user's presence.
         """
-        observer_rooms = yield self.store.get_rooms_for_user(observer_user.to_string())
-        observed_rooms = yield self.store.get_rooms_for_user(observed_user.to_string())
-
-        observer_room_ids = set(r.room_id for r in observer_rooms)
-        observed_room_ids = set(r.room_id for r in observed_rooms)
+        observer_room_ids = yield self.store.get_rooms_for_user(
+            observer_user.to_string()
+        )
+        observed_room_ids = yield self.store.get_rooms_for_user(
+            observed_user.to_string()
+        )
 
         if observer_room_ids & observed_room_ids:
             defer.returnValue(True)
@@ -979,14 +1009,18 @@ def should_notify(old_state, new_state):
     return False
 
 
-def format_user_presence_state(state, now):
+def format_user_presence_state(state, now, include_user_id=True):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
+
+    The "user_id" is optional so that this function can be used to format presence
+    updates for client /sync responses and for federation /send requests.
     """
     content = {
         "presence": state.state,
-        "user_id": state.user_id,
     }
+    if include_user_id:
+        content["user_id"] = state.user_id
     if state.last_active_ts:
         content["last_active_ago"] = now - state.last_active_ts
     if state.status_msg and state.state != PresenceState.OFFLINE:
@@ -1025,7 +1059,6 @@ class PresenceEventSource(object):
         # sending down the rare duplicate is not a concern.
 
         with Measure(self.clock, "presence.get_new_events"):
-            user_id = user.to_string()
             if from_key is not None:
                 from_key = int(from_key)
 
@@ -1034,18 +1067,7 @@ class PresenceEventSource(object):
 
             max_token = self.store.get_current_presence_token()
 
-            plist = yield self.store.get_presence_list_accepted(user.localpart)
-            users_interested_in = set(row["observed_user_id"] for row in plist)
-            users_interested_in.add(user_id)  # So that we receive our own presence
-
-            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
-                user_id
-            )
-            users_interested_in.update(users_who_share_room)
-
-            if explicit_room_id:
-                user_ids = yield self.store.get_users_in_room(explicit_room_id)
-                users_interested_in.update(user_ids)
+            users_interested_in = yield self._get_interested_in(user, explicit_room_id)
 
             user_ids_changed = set()
             changed = None
@@ -1073,16 +1095,13 @@ class PresenceEventSource(object):
 
             updates = yield presence.current_state_for_users(user_ids_changed)
 
-        now = self.clock.time_msec()
-
-        defer.returnValue(([
-            {
-                "type": "m.presence",
-                "content": format_user_presence_state(s, now),
-            }
-            for s in updates.values()
-            if include_offline or s.state != PresenceState.OFFLINE
-        ], max_token))
+        if include_offline:
+            defer.returnValue((updates.values(), max_token))
+        else:
+            defer.returnValue(([
+                s for s in updates.itervalues()
+                if s.state != PresenceState.OFFLINE
+            ], max_token))
 
     def get_current_key(self):
         return self.store.get_current_presence_token()
@@ -1090,6 +1109,31 @@ class PresenceEventSource(object):
     def get_pagination_rows(self, user, pagination_config, key):
         return self.get_new_events(user, from_key=None, include_offline=False)
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _get_interested_in(self, user, explicit_room_id, cache_context):
+        """Returns the set of users that the given user should see presence
+        updates for
+        """
+        user_id = user.to_string()
+        plist = yield self.store.get_presence_list_accepted(
+            user.localpart, on_invalidate=cache_context.invalidate,
+        )
+        users_interested_in = set(row["observed_user_id"] for row in plist)
+        users_interested_in.add(user_id)  # So that we receive our own presence
+
+        users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+            user_id, on_invalidate=cache_context.invalidate,
+        )
+        users_interested_in.update(users_who_share_room)
+
+        if explicit_room_id:
+            user_ids = yield self.store.get_users_in_room(
+                explicit_room_id, on_invalidate=cache_context.invalidate,
+            )
+            users_interested_in.update(user_ids)
+
+        defer.returnValue(users_interested_in)
+
 
 def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
     """Checks the presence of users that have timed out and updates as
@@ -1157,14 +1201,17 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
         # If there are have been no sync for a while (and none ongoing),
         # set presence to offline
         if user_id not in syncing_user_ids:
-            if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
+            # If the user has done something recently but hasn't synced,
+            # don't set them as offline.
+            sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
+            if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
                 state = state.copy_and_replace(
                     state=PresenceState.OFFLINE,
                     status_msg=None,
                 )
                 changed = True
     else:
-        # We expect to be poked occaisonally by the other side.
+        # We expect to be poked occasionally by the other side.
         # This is to protect against forgetful/buggy servers, so that
         # no one gets stuck online forever.
         if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
@@ -1255,3 +1302,66 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
         persist_and_notify = True
 
     return new_state, persist_and_notify, federation_ping
+
+
+@defer.inlineCallbacks
+def get_interested_parties(store, states):
+    """Given a list of states return which entities (rooms, users)
+    are interested in the given states.
+
+    Args:
+        states (list(UserPresenceState))
+
+    Returns:
+        2-tuple: `(room_ids_to_states, users_to_states)`,
+        with each item being a dict of `entity_name` -> `[UserPresenceState]`
+    """
+    room_ids_to_states = {}
+    users_to_states = {}
+    for state in states:
+        room_ids = yield store.get_rooms_for_user(state.user_id)
+        for room_id in room_ids:
+            room_ids_to_states.setdefault(room_id, []).append(state)
+
+        plist = yield store.get_presence_list_observers_accepted(state.user_id)
+        for u in plist:
+            users_to_states.setdefault(u, []).append(state)
+
+        # Always notify self
+        users_to_states.setdefault(state.user_id, []).append(state)
+
+    defer.returnValue((room_ids_to_states, users_to_states))
+
+
+@defer.inlineCallbacks
+def get_interested_remotes(store, states, state_handler):
+    """Given a list of presence states figure out which remote servers
+    should be sent which.
+
+    All the presence states should be for local users only.
+
+    Args:
+        store (DataStore)
+        states (list(UserPresenceState))
+
+    Returns:
+        Deferred list of ([destinations], [UserPresenceState]), where for
+        each row the list of UserPresenceState should be sent to each
+        destination
+    """
+    hosts_and_states = []
+
+    # First we look up the rooms each user is in (as well as any explicit
+    # subscriptions), then for each distinct room we look up the remote
+    # hosts in those rooms.
+    room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
+
+    for room_id, states in room_ids_to_states.iteritems():
+        hosts = yield state_handler.get_current_hosts_in_room(room_id)
+        hosts_and_states.append((hosts, states))
+
+    for user_id, states in users_to_states.iteritems():
+        host = get_domain_from_id(user_id)
+        hosts_and_states.append(([host], states))
+
+    defer.returnValue(hosts_and_states)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 87f74dfb8e..3465a787ab 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,25 +17,87 @@ import logging
 
 from twisted.internet import defer
 
-import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
 from ._base import BaseHandler
 
-
 logger = logging.getLogger(__name__)
 
 
 class ProfileHandler(BaseHandler):
+    PROFILE_UPDATE_MS = 60 * 1000
+    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
 
     def __init__(self, hs):
         super(ProfileHandler, self).__init__(hs)
 
-        self.federation = hs.get_replication_layer()
-        self.federation.register_query_handler(
+        self.federation = hs.get_federation_client()
+        hs.get_federation_registry().register_query_handler(
             "profile", self.on_profile_query
         )
 
+        self.user_directory_handler = hs.get_user_directory_handler()
+
+        if hs.config.worker_app is None:
+            self.clock.looping_call(
+                self._update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+            )
+
+    @defer.inlineCallbacks
+    def get_profile(self, user_id):
+        target_user = UserID.from_string(user_id)
+        if self.hs.is_mine(target_user):
+            displayname = yield self.store.get_profile_displayname(
+                target_user.localpart
+            )
+            avatar_url = yield self.store.get_profile_avatar_url(
+                target_user.localpart
+            )
+
+            defer.returnValue({
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+            })
+        else:
+            try:
+                result = yield self.federation.make_query(
+                    destination=target_user.domain,
+                    query_type="profile",
+                    args={
+                        "user_id": user_id,
+                    },
+                    ignore_backoff=True,
+                )
+                defer.returnValue(result)
+            except CodeMessageException as e:
+                if e.code != 404:
+                    logger.exception("Failed to get displayname")
+
+                raise
+
+    @defer.inlineCallbacks
+    def get_profile_from_cache(self, user_id):
+        """Get the profile information from our local cache. If the user is
+        ours then the profile information will always be corect. Otherwise,
+        it may be out of date/missing.
+        """
+        target_user = UserID.from_string(user_id)
+        if self.hs.is_mine(target_user):
+            displayname = yield self.store.get_profile_displayname(
+                target_user.localpart
+            )
+            avatar_url = yield self.store.get_profile_avatar_url(
+                target_user.localpart
+            )
+
+            defer.returnValue({
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+            })
+        else:
+            profile = yield self.store.get_from_remote_profile_cache(user_id)
+            defer.returnValue(profile or {})
+
     @defer.inlineCallbacks
     def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
@@ -52,14 +114,15 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "displayname",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
                     logger.exception("Failed to get displayname")
 
                 raise
-            except:
+            except Exception:
                 logger.exception("Failed to get displayname")
             else:
                 defer.returnValue(result["displayname"])
@@ -81,7 +144,13 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_displayname
         )
 
-        yield self._update_join_states(requester)
+        if self.hs.config.user_directory_search_all_users:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            yield self.user_directory_handler.handle_local_profile_change(
+                target_user.to_string(), profile
+            )
+
+        yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
     def get_avatar_url(self, target_user):
@@ -99,13 +168,14 @@ class ProfileHandler(BaseHandler):
                     args={
                         "user_id": target_user.to_string(),
                         "field": "avatar_url",
-                    }
+                    },
+                    ignore_backoff=True,
                 )
             except CodeMessageException as e:
                 if e.code != 404:
                     logger.exception("Failed to get avatar_url")
                 raise
-            except:
+            except Exception:
                 logger.exception("Failed to get avatar_url")
 
             defer.returnValue(result["avatar_url"])
@@ -124,7 +194,13 @@ class ProfileHandler(BaseHandler):
             target_user.localpart, new_avatar_url
         )
 
-        yield self._update_join_states(requester)
+        if self.hs.config.user_directory_search_all_users:
+            profile = yield self.store.get_profileinfo(target_user.localpart)
+            yield self.user_directory_handler.handle_local_profile_change(
+                target_user.to_string(), profile
+            )
+
+        yield self._update_join_states(requester, target_user)
 
     @defer.inlineCallbacks
     def on_profile_query(self, args):
@@ -149,34 +225,71 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
-    def _update_join_states(self, requester):
-        user = requester.user
-        if not self.hs.is_mine(user):
+    def _update_join_states(self, requester, target_user):
+        if not self.hs.is_mine(target_user):
             return
 
-        self.ratelimit(requester)
+        yield self.ratelimit(requester)
 
-        joins = yield self.store.get_rooms_for_user(
-            user.to_string(),
+        room_ids = yield self.store.get_rooms_for_user(
+            target_user.to_string(),
         )
 
-        for j in joins:
-            handler = self.hs.get_handlers().room_member_handler
+        for room_id in room_ids:
+            handler = self.hs.get_room_member_handler()
             try:
-                # Assume the user isn't a guest because we don't let guests set
-                # profile or avatar data.
-                # XXX why are we recreating `requester` here for each room?
-                # what was wrong with the `requester` we were passed?
-                requester = synapse.types.create_requester(user)
+                # Assume the target_user isn't a guest,
+                # because we don't let guests set profile or avatar data.
                 yield handler.update_membership(
                     requester,
-                    user,
-                    j.room_id,
+                    target_user,
+                    room_id,
                     "join",  # We treat a profile update like a join.
                     ratelimit=False,  # Try to hide that these events aren't atomic.
                 )
             except Exception as e:
                 logger.warn(
                     "Failed to update join event for room %s - %s",
-                    j.room_id, str(e.message)
+                    room_id, str(e.message)
                 )
+
+    def _update_remote_profile_cache(self):
+        """Called periodically to check profiles of remote users we haven't
+        checked in a while.
+        """
+        entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+            last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
+        )
+
+        for user_id, displayname, avatar_url in entries:
+            is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+                user_id,
+            )
+            if not is_subscribed:
+                yield self.store.maybe_delete_remote_profile_cache(user_id)
+                continue
+
+            try:
+                profile = yield self.federation.make_query(
+                    destination=get_domain_from_id(user_id),
+                    query_type="profile",
+                    args={
+                        "user_id": user_id,
+                    },
+                    ignore_backoff=True,
+                )
+            except Exception:
+                logger.exception("Failed to get avatar_url")
+
+                yield self.store.update_remote_profile_cache(
+                    user_id, displayname, avatar_url
+                )
+                continue
+
+            new_name = profile.get("displayname")
+            new_avatar = profile.get("avatar_url")
+
+            # We always hit update to update the last_check timestamp
+            yield self.store.update_remote_profile_cache(
+                user_id, new_name, new_avatar
+            )
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
new file mode 100644
index 0000000000..5142ae153d
--- /dev/null
+++ b/synapse/handlers/read_marker.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseHandler
+
+from twisted.internet import defer
+
+from synapse.util.async import Linearizer
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+class ReadMarkerHandler(BaseHandler):
+    def __init__(self, hs):
+        super(ReadMarkerHandler, self).__init__(hs)
+        self.server_name = hs.config.server_name
+        self.store = hs.get_datastore()
+        self.read_marker_linearizer = Linearizer(name="read_marker")
+        self.notifier = hs.get_notifier()
+
+    @defer.inlineCallbacks
+    def received_client_read_marker(self, room_id, user_id, event_id):
+        """Updates the read marker for a given user in a given room if the event ID given
+        is ahead in the stream relative to the current read marker.
+
+        This uses a notifier to indicate that account data should be sent down /sync if
+        the read marker has changed.
+        """
+
+        with (yield self.read_marker_linearizer.queue((room_id, user_id))):
+            existing_read_marker = yield self.store.get_account_data_for_room_and_type(
+                user_id, room_id, "m.fully_read",
+            )
+
+            should_update = True
+
+            if existing_read_marker:
+                # Only update if the new marker is ahead in the stream
+                should_update = yield self.store.is_event_after(
+                    event_id,
+                    existing_read_marker['event_id']
+                )
+
+            if should_update:
+                content = {
+                    "event_id": event_id
+                }
+                max_id = yield self.store.add_account_data_to_room(
+                    user_id, room_id, "m.fully_read", content
+                )
+                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 50aa513935..2e0672161c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.util import logcontext
 
 from ._base import BaseHandler
 
@@ -34,7 +35,7 @@ class ReceiptsHandler(BaseHandler):
         self.store = hs.get_datastore()
         self.hs = hs
         self.federation = hs.get_federation_sender()
-        hs.get_replication_layer().register_edu_handler(
+        hs.get_federation_registry().register_edu_handler(
             "m.receipt", self._received_remote_receipt
         )
         self.clock = self.hs.get_clock()
@@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
         is_new = yield self._handle_new_receipts([receipt])
 
         if is_new:
+            # fire off a process in the background to send the receipt to
+            # remote servers
             self._push_remotes([receipt])
 
     @defer.inlineCallbacks
@@ -126,42 +129,46 @@ class ReceiptsHandler(BaseHandler):
 
             defer.returnValue(True)
 
+    @logcontext.preserve_fn   # caller should not yield on this
     @defer.inlineCallbacks
     def _push_remotes(self, receipts):
         """Given a list of receipts, works out which remote servers should be
         poked and pokes them.
         """
-        # TODO: Some of this stuff should be coallesced.
-        for receipt in receipts:
-            room_id = receipt["room_id"]
-            receipt_type = receipt["receipt_type"]
-            user_id = receipt["user_id"]
-            event_ids = receipt["event_ids"]
-            data = receipt["data"]
-
-            users = yield self.state.get_current_user_in_room(room_id)
-            remotedomains = set(get_domain_from_id(u) for u in users)
-            remotedomains = remotedomains.copy()
-            remotedomains.discard(self.server_name)
-
-            logger.debug("Sending receipt to: %r", remotedomains)
-
-            for domain in remotedomains:
-                self.federation.send_edu(
-                    destination=domain,
-                    edu_type="m.receipt",
-                    content={
-                        room_id: {
-                            receipt_type: {
-                                user_id: {
-                                    "event_ids": event_ids,
-                                    "data": data,
+        try:
+            # TODO: Some of this stuff should be coallesced.
+            for receipt in receipts:
+                room_id = receipt["room_id"]
+                receipt_type = receipt["receipt_type"]
+                user_id = receipt["user_id"]
+                event_ids = receipt["event_ids"]
+                data = receipt["data"]
+
+                users = yield self.state.get_current_user_in_room(room_id)
+                remotedomains = set(get_domain_from_id(u) for u in users)
+                remotedomains = remotedomains.copy()
+                remotedomains.discard(self.server_name)
+
+                logger.debug("Sending receipt to: %r", remotedomains)
+
+                for domain in remotedomains:
+                    self.federation.send_edu(
+                        destination=domain,
+                        edu_type="m.receipt",
+                        content={
+                            room_id: {
+                                receipt_type: {
+                                    user_id: {
+                                        "event_ids": event_ids,
+                                        "data": data,
+                                    }
                                 }
-                            }
+                            },
                         },
-                    },
-                    key=(room_id, receipt_type, user_id),
-                )
+                        key=(room_id, receipt_type, user_id),
+                    )
+        except Exception:
+            logger.exception("Error pushing receipts to remote servers")
 
     @defer.inlineCallbacks
     def get_receipts_for_room(self, room_id, to_key):
@@ -210,10 +217,9 @@ class ReceiptEventSource(object):
         else:
             from_key = None
 
-        rooms = yield self.store.get_rooms_for_user(user.to_string())
-        rooms = [room.room_id for room in rooms]
+        room_ids = yield self.store.get_rooms_for_user(user.to_string())
         events = yield self.store.get_linearized_receipts_for_rooms(
-            rooms,
+            room_ids,
             from_key=from_key,
             to_key=to_key,
         )
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 03c6a85fc6..f83c6b3cf8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,7 +15,6 @@
 
 """Contains functions for registering clients."""
 import logging
-import urllib
 
 from twisted.internet import defer
 
@@ -23,8 +22,10 @@ from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
 from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
+from synapse import types
+from synapse.types import UserID, create_requester, RoomID, RoomAlias
+from synapse.util.async import run_on_reactor, Linearizer
+from synapse.util.threepids import check_3pid_allowed
 from ._base import BaseHandler
 
 logger = logging.getLogger(__name__)
@@ -36,21 +37,33 @@ class RegistrationHandler(BaseHandler):
         super(RegistrationHandler, self).__init__(hs)
 
         self.auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
+        self.profile_handler = hs.get_profile_handler()
+        self.user_directory_handler = hs.get_user_directory_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
 
         self._next_generated_user_id = None
 
         self.macaroon_gen = hs.get_macaroon_generator()
 
+        self._generate_user_id_linearizer = Linearizer(
+            name="_generate_user_id_linearizer",
+        )
+
     @defer.inlineCallbacks
     def check_username(self, localpart, guest_access_token=None,
                        assigned_user_id=None):
-        yield run_on_reactor()
+        if types.contains_invalid_mxid_characters(localpart):
+            raise SynapseError(
+                400,
+                "User ID can only contain characters a-z, 0-9, or '=_-./'",
+                Codes.INVALID_USERNAME
+            )
 
-        if urllib.quote(localpart.encode('utf-8')) != localpart:
+        if not localpart:
             raise SynapseError(
                 400,
-                "User ID can only contain characters a-z, 0-9, or '_-./'",
+                "User ID cannot be empty",
                 Codes.INVALID_USERNAME
             )
 
@@ -73,7 +86,7 @@ class RegistrationHandler(BaseHandler):
                     "A different user ID has already been registered for this session",
                 )
 
-        yield self.check_user_id_not_appservice_exclusive(user_id)
+        self.check_user_id_not_appservice_exclusive(user_id)
 
         users = yield self.store.get_users_by_id_case_insensitive(user_id)
         if users:
@@ -123,7 +136,7 @@ class RegistrationHandler(BaseHandler):
         yield run_on_reactor()
         password_hash = None
         if password:
-            password_hash = self.auth_handler().hash(password)
+            password_hash = yield self.auth_handler().hash(password)
 
         if localpart:
             yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -158,6 +171,13 @@ class RegistrationHandler(BaseHandler):
                 ),
                 admin=admin,
             )
+
+            if self.hs.config.user_directory_search_all_users:
+                profile = yield self.store.get_profileinfo(localpart)
+                yield self.user_directory_handler.handle_local_profile_change(
+                    user_id, profile
+                )
+
         else:
             # autogen a sequential user ID
             attempts = 0
@@ -185,10 +205,17 @@ class RegistrationHandler(BaseHandler):
                     token = None
                     attempts += 1
 
+        # auto-join the user to any rooms we're supposed to dump them into
+        fake_requester = create_requester(user_id)
+        for r in self.hs.config.auto_join_rooms:
+            try:
+                yield self._join_user_to_room(fake_requester, r)
+            except Exception as e:
+                logger.error("Failed to join new user to %r: %r", r, e)
+
         # We used to generate default identicons here, but nowadays
         # we want clients to generate their own as part of their branding
         # rather than there being consistent matrix-wide ones, so we don't.
-
         defer.returnValue((user_id, token))
 
     @defer.inlineCallbacks
@@ -246,11 +273,10 @@ class RegistrationHandler(BaseHandler):
         """
         Registers email_id as SAML2 Based Auth.
         """
-        if urllib.quote(localpart) != localpart:
+        if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
                 400,
-                "User ID must only contain characters which do not"
-                " require URL encoding."
+                "User ID can only contain characters a-z, 0-9, or '=_-./'",
             )
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
@@ -279,12 +305,12 @@ class RegistrationHandler(BaseHandler):
         """
 
         for c in threepidCreds:
-            logger.info("validating theeepidcred sid %s on id server %s",
+            logger.info("validating threepidcred sid %s on id server %s",
                         c['sid'], c['idServer'])
             try:
                 identity_handler = self.hs.get_handlers().identity_handler
                 threepid = yield identity_handler.threepid_from_creds(c)
-            except:
+            except Exception:
                 logger.exception("Couldn't validate 3pid")
                 raise RegistrationError(400, "Couldn't validate 3pid")
 
@@ -293,6 +319,11 @@ class RegistrationHandler(BaseHandler):
             logger.info("got threepid with medium '%s' and address '%s'",
                         threepid['medium'], threepid['address'])
 
+            if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+                raise RegistrationError(
+                    403, "Third party identifier is not allowed"
+                )
+
     @defer.inlineCallbacks
     def bind_emails(self, user_id, threepidCreds):
         """Links emails with a user ID and informs an identity server.
@@ -325,9 +356,11 @@ class RegistrationHandler(BaseHandler):
     @defer.inlineCallbacks
     def _generate_user_id(self, reseed=False):
         if reseed or self._next_generated_user_id is None:
-            self._next_generated_user_id = (
-                yield self.store.find_next_generated_user_id_localpart()
-            )
+            with (yield self._generate_user_id_linearizer.queue(())):
+                if reseed or self._next_generated_user_id is None:
+                    self._next_generated_user_id = (
+                        yield self.store.find_next_generated_user_id_localpart()
+                    )
 
         id = self._next_generated_user_id
         self._next_generated_user_id += 1
@@ -411,13 +444,12 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_localpart=user.localpart,
             )
         else:
-            yield self.store.user_delete_access_tokens(user_id=user_id)
+            yield self._auth_handler.delete_access_tokens_for_user(user_id)
             yield self.store.add_access_token_to_user(user_id=user_id, token=token)
 
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
-            profile_handler = self.hs.get_handlers().profile_handler
-            yield profile_handler.set_displayname(
+            yield self.profile_handler.set_displayname(
                 user, requester, displayname, by_admin=True,
             )
 
@@ -427,16 +459,59 @@ class RegistrationHandler(BaseHandler):
         return self.hs.get_auth_handler()
 
     @defer.inlineCallbacks
-    def guest_access_token_for(self, medium, address, inviter_user_id):
+    def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
+        """Get a guest access token for a 3PID, creating a guest account if
+        one doesn't already exist.
+
+        Args:
+            medium (str)
+            address (str)
+            inviter_user_id (str): The user ID who is trying to invite the
+                3PID
+
+        Returns:
+            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+            3PID guest account.
+        """
         access_token = yield self.store.get_3pid_guest_access_token(medium, address)
         if access_token:
-            defer.returnValue(access_token)
+            user_info = yield self.auth.get_user_by_access_token(
+                access_token
+            )
 
-        _, access_token = yield self.register(
+            defer.returnValue((user_info["user"].to_string(), access_token))
+
+        user_id, access_token = yield self.register(
             generate_token=True,
             make_guest=True
         )
         access_token = yield self.store.save_or_get_3pid_guest_access_token(
             medium, address, access_token, inviter_user_id
         )
-        defer.returnValue(access_token)
+
+        defer.returnValue((user_id, access_token))
+
+    @defer.inlineCallbacks
+    def _join_user_to_room(self, requester, room_identifier):
+        room_id = None
+        room_member_handler = self.hs.get_room_member_handler()
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+        elif RoomAlias.is_valid(room_identifier):
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, remote_room_hosts = (
+                yield room_member_handler.lookup_room_alias(room_alias)
+            )
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(400, "%s was not legal room ID or room alias" % (
+                room_identifier,
+            ))
+
+        yield room_member_handler.update_membership(
+            requester=requester,
+            target=requester.user,
+            room_id=room_id,
+            remote_room_hosts=remote_room_hosts,
+            action="join",
+        )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7e7671c9a2..8df8fcbbad 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -60,8 +61,14 @@ class RoomCreationHandler(BaseHandler):
         },
     }
 
+    def __init__(self, hs):
+        super(RoomCreationHandler, self).__init__(hs)
+
+        self.spam_checker = hs.get_spam_checker()
+        self.event_creation_handler = hs.get_event_creation_handler()
+
     @defer.inlineCallbacks
-    def create_room(self, requester, config):
+    def create_room(self, requester, config, ratelimit=True):
         """ Creates a new room.
 
         Args:
@@ -75,14 +82,18 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        self.ratelimit(requester)
+        if not self.spam_checker.user_may_create_room(user_id):
+            raise SynapseError(403, "You are not permitted to create rooms")
+
+        if ratelimit:
+            yield self.ratelimit(requester)
 
         if "room_alias_name" in config:
             for wchar in string.whitespace:
                 if wchar in config["room_alias_name"]:
                     raise SynapseError(400, "Invalid characters in room alias")
 
-            room_alias = RoomAlias.create(
+            room_alias = RoomAlias(
                 config["room_alias_name"],
                 self.hs.hostname,
             )
@@ -99,7 +110,7 @@ class RoomCreationHandler(BaseHandler):
         for i in invite_list:
             try:
                 UserID.from_string(i)
-            except:
+            except Exception:
                 raise SynapseError(400, "Invalid user_id: %s" % (i,))
 
         invite_3pid_list = config.get("invite_3pid", [])
@@ -114,7 +125,7 @@ class RoomCreationHandler(BaseHandler):
         while attempts < 5:
             try:
                 random_string = stringutils.random_string(18)
-                gen_room_id = RoomID.create(
+                gen_room_id = RoomID(
                     random_string,
                     self.hs.hostname,
                 )
@@ -154,24 +165,23 @@ class RoomCreationHandler(BaseHandler):
 
         creation_content = config.get("creation_content", {})
 
-        msg_handler = self.hs.get_handlers().message_handler
-        room_member_handler = self.hs.get_handlers().room_member_handler
+        room_member_handler = self.hs.get_room_member_handler()
 
         yield self._send_events_for_new_room(
             requester,
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config=preset_config,
             invite_list=invite_list,
             initial_state=initial_state,
             creation_content=creation_content,
             room_alias=room_alias,
+            power_level_content_override=config.get("power_level_content_override", {})
         )
 
         if "name" in config:
             name = config["name"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Name,
@@ -184,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
 
         if "topic" in config:
             topic = config["topic"]
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
                 {
                     "type": EventTypes.Topic,
@@ -195,12 +205,12 @@ class RoomCreationHandler(BaseHandler):
                 },
                 ratelimit=False)
 
-        content = {}
-        is_direct = config.get("is_direct", None)
-        if is_direct:
-            content["is_direct"] = is_direct
-
         for invitee in invite_list:
+            content = {}
+            is_direct = config.get("is_direct", None)
+            if is_direct:
+                content["is_direct"] = is_direct
+
             yield room_member_handler.update_membership(
                 requester,
                 UserID.from_string(invitee),
@@ -214,7 +224,7 @@ class RoomCreationHandler(BaseHandler):
             id_server = invite_3pid["id_server"]
             address = invite_3pid["address"]
             medium = invite_3pid["medium"]
-            yield self.hs.get_handlers().room_member_handler.do_3pid_invite(
+            yield self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
                 medium,
@@ -239,13 +249,13 @@ class RoomCreationHandler(BaseHandler):
             self,
             creator,  # A Requester object.
             room_id,
-            msg_handler,
             room_member_handler,
             preset_config,
             invite_list,
             initial_state,
             creation_content,
-            room_alias
+            room_alias,
+            power_level_content_override,
     ):
         def create(etype, content, **kwargs):
             e = {
@@ -261,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
         @defer.inlineCallbacks
         def send(etype, content, **kwargs):
             event = create(etype, content, **kwargs)
-            yield msg_handler.create_and_send_nonmember_event(
+            yield self.event_creation_handler.create_and_send_nonmember_event(
                 creator,
                 event,
                 ratelimit=False
@@ -291,7 +301,15 @@ class RoomCreationHandler(BaseHandler):
             ratelimit=False,
         )
 
-        if (EventTypes.PowerLevels, '') not in initial_state:
+        # We treat the power levels override specially as this needs to be one
+        # of the first events that get sent into a room.
+        pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+        if pl_content is not None:
+            yield send(
+                etype=EventTypes.PowerLevels,
+                content=pl_content,
+            )
+        else:
             power_level_content = {
                 "users": {
                     creator_id: 100,
@@ -316,6 +334,8 @@ class RoomCreationHandler(BaseHandler):
                 for invitee in invite_list:
                     power_level_content["users"][invitee] = 100
 
+            power_level_content.update(power_level_content_override)
+
             yield send(
                 etype=EventTypes.PowerLevels,
                 content=power_level_content,
@@ -356,7 +376,7 @@ class RoomCreationHandler(BaseHandler):
 
 class RoomContextHandler(BaseHandler):
     @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, is_guest):
+    def get_event_context(self, user, room_id, event_id, limit):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -375,12 +395,15 @@ class RoomContextHandler(BaseHandler):
 
         now_token = yield self.hs.get_event_sources().get_current_token()
 
+        users = yield self.store.get_users_in_room(room_id)
+        is_peeking = user.to_string() not in users
+
         def filter_evts(events):
             return filter_events_for_client(
                 self.store,
                 user.to_string(),
                 events,
-                is_peeking=is_guest
+                is_peeking=is_peeking
             )
 
         event = yield self.store.get_event(event_id, get_prev_content=True,
@@ -452,12 +475,9 @@ class RoomEventSource(object):
             user.to_string()
         )
         if app_service:
-            events, end_key = yield self.store.get_appservice_room_stream(
-                service=app_service,
-                from_key=from_key,
-                to_key=to_key,
-                limit=limit,
-            )
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
         else:
             room_events = yield self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 19eebbd43f..5757bb7f8a 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,12 +15,15 @@
 
 from twisted.internet import defer
 
+from six.moves import range
+
 from ._base import BaseHandler
 
 from synapse.api.constants import (
     EventTypes, JoinRules,
 )
 from synapse.util.async import concurrently_execute
+from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.types import ThirdPartyInstanceID
 
@@ -42,8 +45,9 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 class RoomListHandler(BaseHandler):
     def __init__(self, hs):
         super(RoomListHandler, self).__init__(hs)
-        self.response_cache = ResponseCache(hs)
-        self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+        self.response_cache = ResponseCache(hs, "room_list")
+        self.remote_response_cache = ResponseCache(hs, "remote_room_list",
+                                                   timeout_ms=30 * 1000)
 
     def get_local_public_room_list(self, limit=None, since_token=None,
                                    search_filter=None,
@@ -62,23 +66,24 @@ class RoomListHandler(BaseHandler):
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
         """
+        logger.info(
+            "Getting public room list: limit=%r, since=%r, search=%r, network=%r",
+            limit, since_token, bool(search_filter), network_tuple,
+        )
         if search_filter:
             # We explicitly don't bother caching searches or requests for
             # appservice specific lists.
+            logger.info("Bypassing cache as search request.")
             return self._get_public_room_list(
                 limit, since_token, search_filter, network_tuple=network_tuple,
             )
 
         key = (limit, since_token, network_tuple)
-        result = self.response_cache.get(key)
-        if not result:
-            result = self.response_cache.set(
-                key,
-                self._get_public_room_list(
-                    limit, since_token, network_tuple=network_tuple
-                )
-            )
-        return result
+        return self.response_cache.wrap(
+            key,
+            self._get_public_room_list,
+            limit, since_token, network_tuple=network_tuple,
+        )
 
     @defer.inlineCallbacks
     def _get_public_room_list(self, limit=None, since_token=None,
@@ -91,7 +96,6 @@ class RoomListHandler(BaseHandler):
 
         rooms_to_order_value = {}
         rooms_to_num_joined = {}
-        rooms_to_latest_event_ids = {}
 
         newly_visible = []
         newly_unpublished = []
@@ -116,19 +120,26 @@ class RoomListHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def get_order_for_room(room_id):
-            latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
-            if not latest_event_ids:
+            # Most of the rooms won't have changed between the since token and
+            # now (especially if the since token is "now"). So, we can ask what
+            # the current users are in a room (that will hit a cache) and then
+            # check if the room has changed since the since token. (We have to
+            # do it in that order to avoid races).
+            # If things have changed then fall back to getting the current state
+            # at the since token.
+            joined_users = yield self.store.get_users_in_room(room_id)
+            if self.store.has_room_changed_since(room_id, stream_token):
                 latest_event_ids = yield self.store.get_forward_extremeties_for_room(
                     room_id, stream_token
                 )
-                rooms_to_latest_event_ids[room_id] = latest_event_ids
 
-            if not latest_event_ids:
-                return
+                if not latest_event_ids:
+                    return
+
+                joined_users = yield self.state_handler.get_current_user_in_room(
+                    room_id, latest_event_ids,
+                )
 
-            joined_users = yield self.state_handler.get_current_user_in_room(
-                room_id, latest_event_ids,
-            )
             num_joined_users = len(joined_users)
             rooms_to_num_joined[room_id] = num_joined_users
 
@@ -138,6 +149,8 @@ class RoomListHandler(BaseHandler):
             # We want larger rooms to be first, hence negating num_joined_users
             rooms_to_order_value[room_id] = (-num_joined_users, room_id)
 
+        logger.info("Getting ordering for %i rooms since %s",
+                    len(room_ids), stream_token)
         yield concurrently_execute(get_order_for_room, room_ids, 10)
 
         sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -165,34 +178,43 @@ class RoomListHandler(BaseHandler):
                 rooms_to_scan = rooms_to_scan[:since_token.current_limit]
                 rooms_to_scan.reverse()
 
-        # Actually generate the entries. _generate_room_entry will append to
-        # chunk but will stop if len(chunk) > limit
-        chunk = []
-        if limit and not search_filter:
+        logger.info("After sorting and filtering, %i rooms remain",
+                    len(rooms_to_scan))
+
+        # _append_room_entry_to_chunk will append to chunk but will stop if
+        # len(chunk) > limit
+        #
+        # Normally we will generate enough results on the first iteration here,
+        #  but if there is a search filter, _append_room_entry_to_chunk may
+        # filter some results out, in which case we loop again.
+        #
+        # We don't want to scan over the entire range either as that
+        # would potentially waste a lot of work.
+        #
+        # XXX if there is no limit, we may end up DoSing the server with
+        # calls to get_current_state_ids for every single room on the
+        # server. Surely we should cap this somehow?
+        #
+        if limit:
             step = limit + 1
-            for i in xrange(0, len(rooms_to_scan), step):
-                # We iterate here because the vast majority of cases we'll stop
-                # at first iteration, but occaisonally _generate_room_entry
-                # won't append to the chunk and so we need to loop again.
-                # We don't want to scan over the entire range either as that
-                # would potentially waste a lot of work.
-                yield concurrently_execute(
-                    lambda r: self._generate_room_entry(
-                        r, rooms_to_num_joined[r],
-                        chunk, limit, search_filter
-                    ),
-                    rooms_to_scan[i:i + step], 10
-                )
-                if len(chunk) >= limit + 1:
-                    break
         else:
+            # step cannot be zero
+            step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
+
+        chunk = []
+        for i in range(0, len(rooms_to_scan), step):
+            batch = rooms_to_scan[i:i + step]
+            logger.info("Processing %i rooms for result", len(batch))
             yield concurrently_execute(
-                lambda r: self._generate_room_entry(
+                lambda r: self._append_room_entry_to_chunk(
                     r, rooms_to_num_joined[r],
                     chunk, limit, search_filter
                 ),
-                rooms_to_scan, 5
+                batch, 5,
             )
+            logger.info("Now %i rooms in result", len(chunk))
+            if len(chunk) >= limit + 1:
+                break
 
         chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
 
@@ -256,21 +278,36 @@ class RoomListHandler(BaseHandler):
         defer.returnValue(results)
 
     @defer.inlineCallbacks
-    def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
-                             search_filter):
+    def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
+                                    search_filter):
+        """Generate the entry for a room in the public room list and append it
+        to the `chunk` if it matches the search filter
+        """
         if limit and len(chunk) > limit + 1:
             # We've already got enough, so lets just drop it.
             return
 
+        result = yield self.generate_room_entry(room_id, num_joined_users)
+
+        if result and _matches_room_entry(result, search_filter):
+            chunk.append(result)
+
+    @cachedInlineCallbacks(num_args=1, cache_context=True)
+    def generate_room_entry(self, room_id, num_joined_users, cache_context,
+                            with_alias=True, allow_private=False):
+        """Returns the entry for a room
+        """
         result = {
             "room_id": room_id,
             "num_joined_members": num_joined_users,
         }
 
-        current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
+        current_state_ids = yield self.store.get_current_state_ids(
+            room_id, on_invalidate=cache_context.invalidate,
+        )
 
         event_map = yield self.store.get_events([
-            event_id for key, event_id in current_state_ids.items()
+            event_id for key, event_id in current_state_ids.iteritems()
             if key[0] in (
                 EventTypes.JoinRules,
                 EventTypes.Name,
@@ -291,12 +328,15 @@ class RoomListHandler(BaseHandler):
         join_rules_event = current_state.get((EventTypes.JoinRules, ""))
         if join_rules_event:
             join_rule = join_rules_event.content.get("join_rule", None)
-            if join_rule and join_rule != JoinRules.PUBLIC:
+            if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
                 defer.returnValue(None)
 
-        aliases = yield self.store.get_aliases_for_room(room_id)
-        if aliases:
-            result["aliases"] = aliases
+        if with_alias:
+            aliases = yield self.store.get_aliases_for_room(
+                room_id, on_invalidate=cache_context.invalidate
+            )
+            if aliases:
+                result["aliases"] = aliases
 
         name_event = yield current_state.get((EventTypes.Name, ""))
         if name_event:
@@ -334,8 +374,7 @@ class RoomListHandler(BaseHandler):
             if avatar_url:
                 result["avatar_url"] = avatar_url
 
-        if _matches_room_entry(result, search_filter):
-            chunk.append(result)
+        defer.returnValue(result)
 
     @defer.inlineCallbacks
     def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
@@ -365,7 +404,7 @@ class RoomListHandler(BaseHandler):
     def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
                                 search_filter=None, include_all_networks=False,
                                 third_party_instance_id=None,):
-        repl_layer = self.hs.get_replication_layer()
+        repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
             return repl_layer.get_public_rooms(
@@ -378,18 +417,14 @@ class RoomListHandler(BaseHandler):
             server_name, limit, since_token, include_all_networks,
             third_party_instance_id,
         )
-        result = self.remote_response_cache.get(key)
-        if not result:
-            result = self.remote_response_cache.set(
-                key,
-                repl_layer.get_public_rooms(
-                    server_name, limit=limit, since_token=since_token,
-                    search_filter=search_filter,
-                    include_all_networks=include_all_networks,
-                    third_party_instance_id=third_party_instance_id,
-                )
-            )
-        return result
+        return self.remote_response_cache.wrap(
+            key,
+            repl_layer.get_public_rooms,
+            server_name, limit=limit, since_token=since_token,
+            search_filter=search_filter,
+            include_all_networks=include_all_networks,
+            third_party_instance_id=third_party_instance_id,
+        )
 
 
 class RoomListNextBatch(namedtuple("RoomListNextBatch", (
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index b2806555cf..714583f1d5 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import abc
 import logging
 
 from signedjson.key import decode_verify_key_bytes
@@ -29,47 +30,139 @@ from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.types import UserID, RoomID
 from synapse.util.async import Linearizer
 from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+
 
 logger = logging.getLogger(__name__)
 
 id_server_scheme = "https://"
 
 
-class RoomMemberHandler(BaseHandler):
+class RoomMemberHandler(object):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
     #   API that takes ID strings and returns pagination chunks. These concerns
     #   ought to be separated out a lot better.
 
+    __metaclass__ = abc.ABCMeta
+
     def __init__(self, hs):
-        super(RoomMemberHandler, self).__init__(hs)
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+        self.state_handler = hs.get_state_handler()
+        self.config = hs.config
+        self.simple_http_client = hs.get_simple_http_client()
+
+        self.federation_handler = hs.get_handlers().federation_handler
+        self.directory_handler = hs.get_handlers().directory_handler
+        self.registration_handler = hs.get_handlers().registration_handler
+        self.profile_handler = hs.get_profile_handler()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
         self.member_linearizer = Linearizer(name="member")
 
         self.clock = hs.get_clock()
+        self.spam_checker = hs.get_spam_checker()
 
-        self.distributor = hs.get_distributor()
-        self.distributor.declare("user_joined_room")
-        self.distributor.declare("user_left_room")
+    @abc.abstractmethod
+    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+        """Try and join a room that this server is not in
+
+        Args:
+            requester (Requester)
+            remote_room_hosts (list[str]): List of servers that can be used
+                to join via.
+            room_id (str): Room that we are trying to join
+            user (UserID): User who is trying to join
+            content (dict): A dict that should be used as the content of the
+                join event.
+
+        Returns:
+            Deferred
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _remote_reject_invite(self, remote_room_hosts, room_id, target):
+        """Attempt to reject an invite for a room this server is not in. If we
+        fail to do so we locally mark the invite as rejected.
+
+        Args:
+            requester (Requester)
+            remote_room_hosts (list[str]): List of servers to use to try and
+                reject invite
+            room_id (str)
+            target (UserID): The user rejecting the invite
+
+        Returns:
+            Deferred[dict]: A dictionary to be returned to the client, may
+            include event_id etc, or nothing if we locally rejected
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+        """Get a guest access token for a 3PID, creating a guest account if
+        one doesn't already exist.
+
+        Args:
+            requester (Requester)
+            medium (str)
+            address (str)
+            inviter_user_id (str): The user ID who is trying to invite the
+                3PID
+
+        Returns:
+            Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+            3PID guest account.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _user_joined_room(self, target, room_id):
+        """Notifies distributor on master process that the user has joined the
+        room.
+
+        Args:
+            target (UserID)
+            room_id (str)
+
+        Returns:
+            Deferred|None
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _user_left_room(self, target, room_id):
+        """Notifies distributor on master process that the user has left the
+        room.
+
+        Args:
+            target (UserID)
+            room_id (str)
+
+        Returns:
+            Deferred|None
+        """
+        raise NotImplementedError()
 
     @defer.inlineCallbacks
     def _local_membership_update(
         self, requester, target, room_id, membership,
-        prev_event_ids,
+        prev_events_and_hashes,
         txn_id=None,
         ratelimit=True,
         content=None,
     ):
         if content is None:
             content = {}
-        msg_handler = self.hs.get_handlers().message_handler
 
         content["membership"] = membership
         if requester.is_guest:
             content["kind"] = "guest"
 
-        event, context = yield msg_handler.create_event(
+        event, context = yield self.event_creation_hander.create_event(
+            requester,
             {
                 "type": EventTypes.Member,
                 "content": content,
@@ -82,16 +175,18 @@ class RoomMemberHandler(BaseHandler):
             },
             token_id=requester.access_token_id,
             txn_id=txn_id,
-            prev_event_ids=prev_event_ids,
+            prev_events_and_hashes=prev_events_and_hashes,
         )
 
         # Check if this event matches the previous membership event for the user.
-        duplicate = yield msg_handler.deduplicate_state_event(event, context)
+        duplicate = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if duplicate is not None:
             # Discard the new event since this membership change is a no-op.
             defer.returnValue(duplicate)
 
-        yield msg_handler.handle_new_client_event(
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -113,40 +208,16 @@ class RoomMemberHandler(BaseHandler):
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
             if newly_joined:
-                yield user_joined_room(self.distributor, target, room_id)
+                yield self._user_joined_room(target, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target, room_id)
+                    yield self._user_left_room(target, room_id)
 
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def remote_join(self, remote_room_hosts, room_id, user, content):
-        if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
-
-        # We don't do an auth check if we are doing an invite
-        # join dance for now, since we're kinda implicitly checking
-        # that we are allowed to join when we decide whether or not we
-        # need to do the invite/join dance.
-        yield self.hs.get_handlers().federation_handler.do_invite_join(
-            remote_room_hosts,
-            room_id,
-            user.to_string(),
-            content,
-        )
-        yield user_joined_room(self.distributor, user, room_id)
-
-    def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
-        return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
-            remote_room_hosts,
-            room_id,
-            user_id
-        )
-
-    @defer.inlineCallbacks
     def update_membership(
             self,
             requester,
@@ -192,14 +263,19 @@ class RoomMemberHandler(BaseHandler):
         content_specified = bool(content)
         if content is None:
             content = {}
+        else:
+            # We do a copy here as we potentially change some keys
+            # later on.
+            content = dict(content)
 
         effective_membership_state = action
         if action in ["kick", "unban"]:
             effective_membership_state = "leave"
 
+        # if this is a join with a 3pid signature, we may need to turn a 3pid
+        # invite into a normal invite before we can handle the join.
         if third_party_signed is not None:
-            replication = self.hs.get_replication_layer()
-            yield replication.exchange_third_party_invite(
+            yield self.federation_handler.exchange_third_party_invite(
                 third_party_signed["sender"],
                 target.to_string(),
                 room_id,
@@ -209,7 +285,41 @@ class RoomMemberHandler(BaseHandler):
         if not remote_room_hosts:
             remote_room_hosts = []
 
-        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        if effective_membership_state not in ("leave", "ban",):
+            is_blocked = yield self.store.is_room_blocked(room_id)
+            if is_blocked:
+                raise SynapseError(403, "This room has been blocked on this server")
+
+        if effective_membership_state == "invite":
+            block_invite = False
+            is_requester_admin = yield self.auth.is_server_admin(
+                requester.user,
+            )
+            if not is_requester_admin:
+                if self.config.block_non_admin_invites:
+                    logger.info(
+                        "Blocking invite: user is not admin and non-admin "
+                        "invites disabled"
+                    )
+                    block_invite = True
+
+                if not self.spam_checker.user_may_invite(
+                    requester.user.to_string(), target.to_string(), room_id,
+                ):
+                    logger.info("Blocking invite due to spam checker")
+                    block_invite = True
+
+            if block_invite:
+                raise SynapseError(
+                    403, "Invites have been disabled on this server",
+                )
+
+        prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+            room_id,
+        )
+        latest_event_ids = (
+            event_id for (event_id, _, _) in prev_events_and_hashes
+        )
         current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )
@@ -250,13 +360,13 @@ class RoomMemberHandler(BaseHandler):
                     raise AuthError(403, "Guest access not allowed")
 
             if not is_host_in_room:
-                inviter = yield self.get_inviter(target.to_string(), room_id)
+                inviter = yield self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
                     remote_room_hosts.append(inviter.domain)
 
                 content["membership"] = Membership.JOIN
 
-                profile = self.hs.get_handlers().profile_handler
+                profile = self.profile_handler
                 if not content_specified:
                     content["displayname"] = yield profile.get_displayname(target)
                     content["avatar_url"] = yield profile.get_avatar_url(target)
@@ -264,15 +374,15 @@ class RoomMemberHandler(BaseHandler):
                 if requester.is_guest:
                     content["kind"] = "guest"
 
-                ret = yield self.remote_join(
-                    remote_room_hosts, room_id, target, content
+                ret = yield self._remote_join(
+                    requester, remote_room_hosts, room_id, target, content
                 )
                 defer.returnValue(ret)
 
         elif effective_membership_state == Membership.LEAVE:
             if not is_host_in_room:
                 # perhaps we've been invited
-                inviter = yield self.get_inviter(target.to_string(), room_id)
+                inviter = yield self._get_inviter(target.to_string(), room_id)
                 if not inviter:
                     raise SynapseError(404, "Not a known room")
 
@@ -286,20 +396,10 @@ class RoomMemberHandler(BaseHandler):
                 else:
                     # send the rejection to the inviter's HS.
                     remote_room_hosts = remote_room_hosts + [inviter.domain]
-
-                    try:
-                        ret = yield self.reject_remote_invite(
-                            target.to_string(), room_id, remote_room_hosts
-                        )
-                        defer.returnValue(ret)
-                    except SynapseError as e:
-                        logger.warn("Failed to reject invite: %s", e)
-
-                        yield self.store.locally_reject_invite(
-                            target.to_string(), room_id
-                        )
-
-                        defer.returnValue({})
+                    res = yield self._remote_reject_invite(
+                        requester, remote_room_hosts, room_id, target,
+                    )
+                    defer.returnValue(res)
 
         res = yield self._local_membership_update(
             requester=requester,
@@ -308,7 +408,7 @@ class RoomMemberHandler(BaseHandler):
             membership=effective_membership_state,
             txn_id=txn_id,
             ratelimit=ratelimit,
-            prev_event_ids=latest_event_ids,
+            prev_events_and_hashes=prev_events_and_hashes,
             content=content,
         )
         defer.returnValue(res)
@@ -354,8 +454,9 @@ class RoomMemberHandler(BaseHandler):
         else:
             requester = synapse.types.create_requester(target_user)
 
-        message_handler = self.hs.get_handlers().message_handler
-        prev_event = yield message_handler.deduplicate_state_event(event, context)
+        prev_event = yield self.event_creation_hander.deduplicate_state_event(
+            event, context,
+        )
         if prev_event is not None:
             return
 
@@ -367,7 +468,12 @@ class RoomMemberHandler(BaseHandler):
                     # so don't really fit into the general auth process.
                     raise AuthError(403, "Guest access not allowed")
 
-        yield message_handler.handle_new_client_event(
+        if event.membership not in (Membership.LEAVE, Membership.BAN):
+            is_blocked = yield self.store.is_room_blocked(room_id)
+            if is_blocked:
+                raise SynapseError(403, "This room has been blocked on this server")
+
+        yield self.event_creation_hander.handle_new_client_event(
             requester,
             event,
             context,
@@ -389,12 +495,12 @@ class RoomMemberHandler(BaseHandler):
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
             if newly_joined:
-                yield user_joined_room(self.distributor, target_user, room_id)
+                yield self._user_joined_room(target_user, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = yield self.store.get_event(prev_member_event_id)
                 if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target_user, room_id)
+                    yield self._user_left_room(target_user, room_id)
 
     @defer.inlineCallbacks
     def _can_guest_join(self, current_state_ids):
@@ -428,7 +534,7 @@ class RoomMemberHandler(BaseHandler):
         Raises:
             SynapseError if room alias could not be found.
         """
-        directory_handler = self.hs.get_handlers().directory_handler
+        directory_handler = self.directory_handler
         mapping = yield directory_handler.get_association(room_alias)
 
         if not mapping:
@@ -440,7 +546,7 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue((RoomID.from_string(room_id), servers))
 
     @defer.inlineCallbacks
-    def get_inviter(self, user_id, room_id):
+    def _get_inviter(self, user_id, room_id):
         invite = yield self.store.get_invite_for_user_in_room(
             user_id=user_id,
             room_id=room_id,
@@ -459,6 +565,16 @@ class RoomMemberHandler(BaseHandler):
             requester,
             txn_id
     ):
+        if self.config.block_non_admin_invites:
+            is_requester_admin = yield self.auth.is_server_admin(
+                requester.user,
+            )
+            if not is_requester_admin:
+                raise SynapseError(
+                    403, "Invites have been disabled on this server",
+                    Codes.FORBIDDEN,
+                )
+
         invitee = yield self._lookup_3pid(
             id_server, medium, address
         )
@@ -496,7 +612,7 @@ class RoomMemberHandler(BaseHandler):
             str: the matrix ID of the 3pid, or None if it is not recognized.
         """
         try:
-            data = yield self.hs.get_simple_http_client().get_json(
+            data = yield self.simple_http_client.get_json(
                 "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
                 {
                     "medium": medium,
@@ -507,7 +623,7 @@ class RoomMemberHandler(BaseHandler):
             if "mxid" in data:
                 if "signatures" not in data:
                     raise AuthError(401, "No signatures on 3pid binding")
-                self.verify_any_signature(data, id_server)
+                yield self._verify_any_signature(data, id_server)
                 defer.returnValue(data["mxid"])
 
         except IOError as e:
@@ -515,11 +631,11 @@ class RoomMemberHandler(BaseHandler):
             defer.returnValue(None)
 
     @defer.inlineCallbacks
-    def verify_any_signature(self, data, server_hostname):
+    def _verify_any_signature(self, data, server_hostname):
         if server_hostname not in data["signatures"]:
             raise AuthError(401, "No signature from server %s" % (server_hostname,))
         for key_name, signature in data["signatures"][server_hostname].items():
-            key_data = yield self.hs.get_simple_http_client().get_json(
+            key_data = yield self.simple_http_client.get_json(
                 "%s%s/_matrix/identity/api/v1/pubkey/%s" %
                 (id_server_scheme, server_hostname, key_name,),
             )
@@ -544,7 +660,7 @@ class RoomMemberHandler(BaseHandler):
             user,
             txn_id
     ):
-        room_state = yield self.hs.get_state_handler().get_current_state(room_id)
+        room_state = yield self.state_handler.get_current_state(room_id)
 
         inviter_display_name = ""
         inviter_avatar_url = ""
@@ -575,6 +691,7 @@ class RoomMemberHandler(BaseHandler):
 
         token, public_keys, fallback_public_key, display_name = (
             yield self._ask_id_server_for_third_party_invite(
+                requester=requester,
                 id_server=id_server,
                 medium=medium,
                 address=address,
@@ -589,8 +706,7 @@ class RoomMemberHandler(BaseHandler):
             )
         )
 
-        msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_nonmember_event(
+        yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.ThirdPartyInvite,
@@ -612,6 +728,7 @@ class RoomMemberHandler(BaseHandler):
     @defer.inlineCallbacks
     def _ask_id_server_for_third_party_invite(
             self,
+            requester,
             id_server,
             medium,
             address,
@@ -628,6 +745,7 @@ class RoomMemberHandler(BaseHandler):
         Asks an identity server for a third party invite.
 
         Args:
+            requester (Requester)
             id_server (str): hostname + optional port for the identity server.
             medium (str): The literal string "email".
             address (str): The third party address being invited.
@@ -669,24 +787,20 @@ class RoomMemberHandler(BaseHandler):
             "sender_avatar_url": inviter_avatar_url,
         }
 
-        if self.hs.config.invite_3pid_guest:
-            registration_handler = self.hs.get_handlers().registration_handler
-            guest_access_token = yield registration_handler.guest_access_token_for(
+        if self.config.invite_3pid_guest:
+            guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+                requester=requester,
                 medium=medium,
                 address=address,
                 inviter_user_id=inviter_user_id,
             )
 
-            guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
-                guest_access_token
-            )
-
             invite_config.update({
                 "guest_access_token": guest_access_token,
-                "guest_user_id": guest_user_info["user"].to_string(),
+                "guest_user_id": guest_user_id,
             })
 
-        data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
+        data = yield self.simple_http_client.post_urlencoded_get_json(
             is_url,
             invite_config
         )
@@ -709,25 +823,6 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue((token, public_keys, fallback_public_key, display_name))
 
     @defer.inlineCallbacks
-    def forget(self, user, room_id):
-        user_id = user.to_string()
-
-        member = yield self.state_handler.get_current_state(
-            room_id=room_id,
-            event_type=EventTypes.Member,
-            state_key=user_id
-        )
-        membership = member.membership if member else None
-
-        if membership is not None and membership != Membership.LEAVE:
-            raise SynapseError(400, "User %s in room %s" % (
-                user_id, room_id
-            ))
-
-        if membership:
-            yield self.store.forget(user_id, room_id)
-
-    @defer.inlineCallbacks
     def _is_host_in_room(self, current_state_ids):
         # Have we just created the room, and is this about to be the very
         # first member event?
@@ -735,10 +830,11 @@ class RoomMemberHandler(BaseHandler):
         if len(current_state_ids) == 1 and create_event_id:
             defer.returnValue(self.hs.is_mine_id(create_event_id))
 
-        for (etype, state_key), event_id in current_state_ids.items():
+        for etype, state_key in current_state_ids:
             if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                 continue
 
+            event_id = current_state_ids[(etype, state_key)]
             event = yield self.store.get_event(event_id, allow_none=True)
             if not event:
                 continue
@@ -747,3 +843,102 @@ class RoomMemberHandler(BaseHandler):
                 defer.returnValue(True)
 
         defer.returnValue(False)
+
+
+class RoomMemberMasterHandler(RoomMemberHandler):
+    def __init__(self, hs):
+        super(RoomMemberMasterHandler, self).__init__(hs)
+
+        self.distributor = hs.get_distributor()
+        self.distributor.declare("user_joined_room")
+        self.distributor.declare("user_left_room")
+
+    @defer.inlineCallbacks
+    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+        """Implements RoomMemberHandler._remote_join
+        """
+        # filter ourselves out of remote_room_hosts: do_invite_join ignores it
+        # and if it is the only entry we'd like to return a 404 rather than a
+        # 500.
+
+        remote_room_hosts = [
+            host for host in remote_room_hosts if host != self.hs.hostname
+        ]
+
+        if len(remote_room_hosts) == 0:
+            raise SynapseError(404, "No known servers")
+
+        # We don't do an auth check if we are doing an invite
+        # join dance for now, since we're kinda implicitly checking
+        # that we are allowed to join when we decide whether or not we
+        # need to do the invite/join dance.
+        yield self.federation_handler.do_invite_join(
+            remote_room_hosts,
+            room_id,
+            user.to_string(),
+            content,
+        )
+        yield self._user_joined_room(user, room_id)
+
+    @defer.inlineCallbacks
+    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+        """Implements RoomMemberHandler._remote_reject_invite
+        """
+        fed_handler = self.federation_handler
+        try:
+            ret = yield fed_handler.do_remotely_reject_invite(
+                remote_room_hosts,
+                room_id,
+                target.to_string(),
+            )
+            defer.returnValue(ret)
+        except Exception as e:
+            # if we were unable to reject the exception, just mark
+            # it as rejected on our end and plough ahead.
+            #
+            # The 'except' clause is very broad, but we need to
+            # capture everything from DNS failures upwards
+            #
+            logger.warn("Failed to reject invite: %s", e)
+
+            yield self.store.locally_reject_invite(
+                target.to_string(), room_id
+            )
+            defer.returnValue({})
+
+    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+        """Implements RoomMemberHandler.get_or_register_3pid_guest
+        """
+        rg = self.registration_handler
+        return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+
+    def _user_joined_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_joined_room
+        """
+        return user_joined_room(self.distributor, target, room_id)
+
+    def _user_left_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_left_room
+        """
+        return user_left_room(self.distributor, target, room_id)
+
+    @defer.inlineCallbacks
+    def forget(self, user, room_id):
+        user_id = user.to_string()
+
+        member = yield self.state_handler.get_current_state(
+            room_id=room_id,
+            event_type=EventTypes.Member,
+            state_key=user_id
+        )
+        membership = member.membership if member else None
+
+        if membership is not None and membership not in [
+            Membership.LEAVE, Membership.BAN
+        ]:
+            raise SynapseError(400, "User %s in room %s" % (
+                user_id, room_id
+            ))
+
+        if membership:
+            yield self.store.forget(user_id, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
new file mode 100644
index 0000000000..493aec1e48
--- /dev/null
+++ b/synapse/handlers/room_member_worker.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.room_member import RoomMemberHandler
+from synapse.replication.http.membership import (
+    remote_join, remote_reject_invite, get_or_register_3pid_guest,
+    notify_user_membership_change,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberWorkerHandler(RoomMemberHandler):
+    @defer.inlineCallbacks
+    def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+        """Implements RoomMemberHandler._remote_join
+        """
+        if len(remote_room_hosts) == 0:
+            raise SynapseError(404, "No known servers")
+
+        ret = yield remote_join(
+            self.simple_http_client,
+            host=self.config.worker_replication_host,
+            port=self.config.worker_replication_http_port,
+            requester=requester,
+            remote_room_hosts=remote_room_hosts,
+            room_id=room_id,
+            user_id=user.to_string(),
+            content=content,
+        )
+
+        yield self._user_joined_room(user, room_id)
+
+        defer.returnValue(ret)
+
+    def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+        """Implements RoomMemberHandler._remote_reject_invite
+        """
+        return remote_reject_invite(
+            self.simple_http_client,
+            host=self.config.worker_replication_host,
+            port=self.config.worker_replication_http_port,
+            requester=requester,
+            remote_room_hosts=remote_room_hosts,
+            room_id=room_id,
+            user_id=target.to_string(),
+        )
+
+    def _user_joined_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_joined_room
+        """
+        return notify_user_membership_change(
+            self.simple_http_client,
+            host=self.config.worker_replication_host,
+            port=self.config.worker_replication_http_port,
+            user_id=target.to_string(),
+            room_id=room_id,
+            change="joined",
+        )
+
+    def _user_left_room(self, target, room_id):
+        """Implements RoomMemberHandler._user_left_room
+        """
+        return notify_user_membership_change(
+            self.simple_http_client,
+            host=self.config.worker_replication_host,
+            port=self.config.worker_replication_http_port,
+            user_id=target.to_string(),
+            room_id=room_id,
+            change="left",
+        )
+
+    def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
+        """Implements RoomMemberHandler.get_or_register_3pid_guest
+        """
+        return get_or_register_3pid_guest(
+            self.simple_http_client,
+            host=self.config.worker_replication_host,
+            port=self.config.worker_replication_http_port,
+            requester=requester,
+            medium=medium,
+            address=address,
+            inviter_user_id=inviter_user_id,
+        )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index df75d70fac..9772ed1a0e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -61,7 +61,7 @@ class SearchHandler(BaseHandler):
                 assert batch_group is not None
                 assert batch_group_key is not None
                 assert batch_token is not None
-            except:
+            except Exception:
                 raise SynapseError(400, "Invalid batch")
 
         try:
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
new file mode 100644
index 0000000000..e057ae54c9
--- /dev/null
+++ b/synapse/handlers/set_password.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, StoreError, SynapseError
+from ._base import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class SetPasswordHandler(BaseHandler):
+    """Handler which deals with changing user account passwords"""
+    def __init__(self, hs):
+        super(SetPasswordHandler, self).__init__(hs)
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
+
+    @defer.inlineCallbacks
+    def set_password(self, user_id, newpassword, requester=None):
+        password_hash = yield self._auth_handler.hash(newpassword)
+
+        except_device_id = requester.device_id if requester else None
+        except_access_token_id = requester.access_token_id if requester else None
+
+        try:
+            yield self.store.user_set_password_hash(user_id, password_hash)
+        except StoreError as e:
+            if e.code == 404:
+                raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
+            raise e
+
+        # we want to log out all of the user's other sessions. First delete
+        # all his other devices.
+        yield self._device_handler.delete_all_devices_for_user(
+            user_id, except_device_id=except_device_id,
+        )
+
+        # and now delete any access tokens which weren't associated with
+        # devices (or were associated with this device).
+        yield self._auth_handler.delete_access_tokens_for_user(
+            user_id, except_token_id=except_access_token_id,
+        )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d7dcd1ce5b..b52e4c2aff 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@ from synapse.util.metrics import Measure, measure_func
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.visibility import filter_events_for_client
+from synapse.types import RoomStreamToken
 
 from twisted.internet import defer
 
@@ -51,6 +52,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
         to tell if room needs to be part of the sync result.
         """
         return bool(self.events)
+    __bool__ = __nonzero__  # python3
 
 
 class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
@@ -75,6 +77,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
             # nb the notification count does not, er, count: if there's nothing
             # else in the result, we don't need to send it.
         )
+    __bool__ = __nonzero__  # python3
 
 
 class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
@@ -94,6 +97,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
             or self.state
             or self.account_data
         )
+    __bool__ = __nonzero__  # python3
 
 
 class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@@ -105,6 +109,30 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
     def __nonzero__(self):
         """Invited rooms should always be reported to the client"""
         return True
+    __bool__ = __nonzero__  # python3
+
+
+class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
+    "join",
+    "invite",
+    "leave",
+])):
+    __slots__ = []
+
+    def __nonzero__(self):
+        return bool(self.join or self.invite or self.leave)
+    __bool__ = __nonzero__  # python3
+
+
+class DeviceLists(collections.namedtuple("DeviceLists", [
+    "changed",   # list of user_ids whose devices may have changed
+    "left",      # list of user_ids whose devices we no longer track
+])):
+    __slots__ = []
+
+    def __nonzero__(self):
+        return bool(self.changed or self.left)
+    __bool__ = __nonzero__  # python3
 
 
 class SyncResult(collections.namedtuple("SyncResult", [
@@ -116,6 +144,9 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "archived",  # ArchivedSyncResult for each archived room.
     "to_device",  # List of direct messages for the device.
     "device_lists",  # List of user_ids whose devices have chanegd
+    "device_one_time_keys_count",  # Dict of algorithm to count for one time keys
+                                   # for this device
+    "groups",
 ])):
     __slots__ = []
 
@@ -131,8 +162,10 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.archived or
             self.account_data or
             self.to_device or
-            self.device_lists
+            self.device_lists or
+            self.groups
         )
+    __bool__ = __nonzero__  # python3
 
 
 class SyncHandler(object):
@@ -143,7 +176,7 @@ class SyncHandler(object):
         self.presence_handler = hs.get_presence_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
-        self.response_cache = ResponseCache(hs)
+        self.response_cache = ResponseCache(hs, "sync")
         self.state = hs.get_state_handler()
 
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
@@ -154,15 +187,11 @@ class SyncHandler(object):
         Returns:
             A Deferred SyncResult.
         """
-        result = self.response_cache.get(sync_config.request_key)
-        if not result:
-            result = self.response_cache.set(
-                sync_config.request_key,
-                self._wait_for_sync_for_user(
-                    sync_config, since_token, timeout, full_state
-                )
-            )
-        return result
+        return self.response_cache.wrap(
+            sync_config.request_key,
+            self._wait_for_sync_for_user,
+            sync_config, since_token, timeout, full_state,
+        )
 
     @defer.inlineCallbacks
     def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -209,10 +238,10 @@ class SyncHandler(object):
         defer.returnValue(rules)
 
     @defer.inlineCallbacks
-    def ephemeral_by_room(self, sync_config, now_token, since_token=None):
+    def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
         """Get the ephemeral events for each room the user is in
         Args:
-            sync_config (SyncConfig): The flags, filters and user for the sync.
+            sync_result_builder(SyncResultBuilder)
             now_token (StreamToken): Where the server is currently up to.
             since_token (StreamToken): Where the server was when the client
                 last synced.
@@ -222,11 +251,12 @@ class SyncHandler(object):
             typing events for that room.
         """
 
+        sync_config = sync_result_builder.sync_config
+
         with Measure(self.clock, "ephemeral_by_room"):
             typing_key = since_token.typing_key if since_token else "0"
 
-            rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
-            room_ids = [room.room_id for room in rooms]
+            room_ids = sync_result_builder.joined_room_ids
 
             typing_source = self.event_sources.sources["typing"]
             typing, typing_key = yield typing_source.get_new_events(
@@ -288,10 +318,20 @@ class SyncHandler(object):
 
             if recents:
                 recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+                # We check if there are any state events, if there are then we pass
+                # all current state events to the filter_events function. This is to
+                # ensure that we always include current state in the timeline
+                current_state_ids = frozenset()
+                if any(e.is_state() for e in recents):
+                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = frozenset(current_state_ids.itervalues())
+
                 recents = yield filter_events_for_client(
                     self.store,
                     sync_config.user.to_string(),
                     recents,
+                    always_include_ids=current_state_ids,
                 )
             else:
                 recents = []
@@ -323,10 +363,20 @@ class SyncHandler(object):
                 loaded_recents = sync_config.filter_collection.filter_room_timeline(
                     events
                 )
+
+                # We check if there are any state events, if there are then we pass
+                # all current state events to the filter_events function. This is to
+                # ensure that we always include current state in the timeline
+                current_state_ids = frozenset()
+                if any(e.is_state() for e in loaded_recents):
+                    current_state_ids = yield self.state.get_current_state_ids(room_id)
+                    current_state_ids = frozenset(current_state_ids.itervalues())
+
                 loaded_recents = yield filter_events_for_client(
                     self.store,
                     sync_config.user.to_string(),
                     loaded_recents,
+                    always_include_ids=current_state_ids,
                 )
                 loaded_recents.extend(recents)
                 recents = loaded_recents
@@ -520,10 +570,22 @@ class SyncHandler(object):
         # Always use the `now_token` in `SyncResultBuilder`
         now_token = yield self.event_sources.get_current_token()
 
+        user_id = sync_config.user.to_string()
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
+        else:
+            joined_room_ids = yield self.get_rooms_for_user_at(
+                user_id, now_token.room_stream_id,
+            )
+
         sync_result_builder = SyncResultBuilder(
             sync_config, full_state,
             since_token=since_token,
             now_token=now_token,
+            joined_room_ids=joined_room_ids,
         )
 
         account_data_by_room = yield self._generate_sync_entry_for_account_data(
@@ -533,7 +595,8 @@ class SyncHandler(object):
         res = yield self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
-        newly_joined_rooms, newly_joined_users = res
+        newly_joined_rooms, newly_joined_users, _, _ = res
+        _, _, newly_left_rooms, newly_left_users = res
 
         block_all_presence_data = (
             since_token is None and
@@ -547,9 +610,22 @@ class SyncHandler(object):
         yield self._generate_sync_entry_for_to_device(sync_result_builder)
 
         device_lists = yield self._generate_sync_entry_for_device_list(
-            sync_result_builder
+            sync_result_builder,
+            newly_joined_rooms=newly_joined_rooms,
+            newly_joined_users=newly_joined_users,
+            newly_left_rooms=newly_left_rooms,
+            newly_left_users=newly_left_users,
         )
 
+        device_id = sync_config.device_id
+        one_time_key_counts = {}
+        if device_id:
+            one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+                user_id, device_id
+            )
+
+        yield self._generate_sync_entry_for_groups(sync_result_builder)
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -558,31 +634,103 @@ class SyncHandler(object):
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
+            groups=sync_result_builder.groups,
+            device_one_time_keys_count=one_time_key_counts,
             next_batch=sync_result_builder.now_token,
         ))
 
+    @measure_func("_generate_sync_entry_for_groups")
+    @defer.inlineCallbacks
+    def _generate_sync_entry_for_groups(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+        now_token = sync_result_builder.now_token
+
+        if since_token and since_token.groups_key:
+            results = yield self.store.get_groups_changes_for_user(
+                user_id, since_token.groups_key, now_token.groups_key,
+            )
+        else:
+            results = yield self.store.get_all_groups_for_user(
+                user_id, now_token.groups_key,
+            )
+
+        invited = {}
+        joined = {}
+        left = {}
+        for result in results:
+            membership = result["membership"]
+            group_id = result["group_id"]
+            gtype = result["type"]
+            content = result["content"]
+
+            if membership == "join":
+                if gtype == "membership":
+                    # TODO: Add profile
+                    content.pop("membership", None)
+                    joined[group_id] = content["content"]
+                else:
+                    joined.setdefault(group_id, {})[gtype] = content
+            elif membership == "invite":
+                if gtype == "membership":
+                    content.pop("membership", None)
+                    invited[group_id] = content["content"]
+            else:
+                if gtype == "membership":
+                    left[group_id] = content["content"]
+
+        sync_result_builder.groups = GroupsSyncResult(
+            join=joined,
+            invite=invited,
+            leave=left,
+        )
+
     @measure_func("_generate_sync_entry_for_device_list")
     @defer.inlineCallbacks
-    def _generate_sync_entry_for_device_list(self, sync_result_builder):
+    def _generate_sync_entry_for_device_list(self, sync_result_builder,
+                                             newly_joined_rooms, newly_joined_users,
+                                             newly_left_rooms, newly_left_users):
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
         if since_token and since_token.device_list_key:
-            rooms = yield self.store.get_rooms_for_user(user_id)
-            room_ids = set(r.room_id for r in rooms)
-
-            user_ids_changed = set()
             changed = yield self.store.get_user_whose_devices_changed(
                 since_token.device_list_key
             )
-            for other_user_id in changed:
-                other_rooms = yield self.store.get_rooms_for_user(other_user_id)
-                if room_ids.intersection(e.room_id for e in other_rooms):
-                    user_ids_changed.add(other_user_id)
 
-            defer.returnValue(user_ids_changed)
+            # TODO: Be more clever than this, i.e. remove users who we already
+            # share a room with?
+            for room_id in newly_joined_rooms:
+                joined_users = yield self.state.get_current_user_in_room(room_id)
+                newly_joined_users.update(joined_users)
+
+            for room_id in newly_left_rooms:
+                left_users = yield self.state.get_current_user_in_room(room_id)
+                newly_left_users.update(left_users)
+
+            # TODO: Check that these users are actually new, i.e. either they
+            # weren't in the previous sync *or* they left and rejoined.
+            changed.update(newly_joined_users)
+
+            if not changed and not newly_left_users:
+                defer.returnValue(DeviceLists(
+                    changed=[],
+                    left=newly_left_users,
+                ))
+
+            users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+                user_id
+            )
+
+            defer.returnValue(DeviceLists(
+                changed=users_who_share_room & changed,
+                left=set(newly_left_users) - users_who_share_room,
+            ))
         else:
-            defer.returnValue([])
+            defer.returnValue(DeviceLists(
+                changed=[],
+                left=[],
+            ))
 
     @defer.inlineCallbacks
     def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -609,14 +757,14 @@ class SyncHandler(object):
             deleted = yield self.store.delete_messages_for_device(
                 user_id, device_id, since_stream_id
             )
-            logger.info("Deleted %d to-device messages up to %d",
-                        deleted, since_stream_id)
+            logger.debug("Deleted %d to-device messages up to %d",
+                         deleted, since_stream_id)
 
             messages, stream_id = yield self.store.get_new_messages_for_device(
                 user_id, device_id, since_stream_id, now_token.to_device_key
             )
 
-            logger.info(
+            logger.debug(
                 "Returning %d to-device messages between %d and %d (current token: %d)",
                 len(messages), since_stream_id, stream_id, now_token.to_device_key
             )
@@ -721,14 +869,14 @@ class SyncHandler(object):
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
-        states = yield self.presence_handler.get_states(
-            extra_users_ids,
-            as_event=True,
-        )
-        presence.extend(states)
+        if extra_users_ids:
+            states = yield self.presence_handler.get_states(
+                extra_users_ids,
+            )
+            presence.extend(states)
 
-        # Deduplicate the presence entries so that there's at most one per user
-        presence = {p["content"]["user_id"]: p for p in presence}.values()
+            # Deduplicate the presence entries so that there's at most one per user
+            presence = {p.user_id: p for p in presence}.values()
 
         presence = sync_config.filter_collection.filter_presence(
             presence
@@ -746,8 +894,8 @@ class SyncHandler(object):
             account_data_by_room(dict): Dictionary of per room account data
 
         Returns:
-            Deferred(tuple): Returns a 2-tuple of
-            `(newly_joined_rooms, newly_joined_users)`
+            Deferred(tuple): Returns a 4-tuple of
+            `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
         """
         user_id = sync_result_builder.sync_config.user.to_string()
         block_all_room_ephemeral = (
@@ -759,12 +907,27 @@ class SyncHandler(object):
             ephemeral_by_room = {}
         else:
             now_token, ephemeral_by_room = yield self.ephemeral_by_room(
-                sync_result_builder.sync_config,
+                sync_result_builder,
                 now_token=sync_result_builder.now_token,
                 since_token=sync_result_builder.since_token,
             )
             sync_result_builder.now_token = now_token
 
+        # We check up front if anything has changed, if it hasn't then there is
+        # no point in going futher.
+        since_token = sync_result_builder.since_token
+        if not sync_result_builder.full_state:
+            if since_token and not ephemeral_by_room and not account_data_by_room:
+                have_changed = yield self._have_rooms_changed(sync_result_builder)
+                if not have_changed:
+                    tags_by_room = yield self.store.get_updated_tags(
+                        user_id,
+                        since_token.account_data_key,
+                    )
+                    if not tags_by_room:
+                        logger.debug("no-oping sync")
+                        defer.returnValue(([], [], [], []))
+
         ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
             "m.ignored_user_list", user_id=user_id,
         )
@@ -774,17 +937,17 @@ class SyncHandler(object):
         else:
             ignored_users = frozenset()
 
-        if sync_result_builder.since_token:
+        if since_token:
             res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
-            room_entries, invited, newly_joined_rooms = res
+            room_entries, invited, newly_joined_rooms, newly_left_rooms = res
 
             tags_by_room = yield self.store.get_updated_tags(
-                user_id,
-                sync_result_builder.since_token.account_data_key,
+                user_id, since_token.account_data_key,
             )
         else:
             res = yield self._get_all_rooms(sync_result_builder, ignored_users)
             room_entries, invited, newly_joined_rooms = res
+            newly_left_rooms = []
 
             tags_by_room = yield self.store.get_tags_for_user(user_id)
 
@@ -805,17 +968,55 @@ class SyncHandler(object):
 
         # Now we want to get any newly joined users
         newly_joined_users = set()
-        if sync_result_builder.since_token:
+        newly_left_users = set()
+        if since_token:
             for joined_sync in sync_result_builder.joined:
                 it = itertools.chain(
-                    joined_sync.timeline.events, joined_sync.state.values()
+                    joined_sync.timeline.events, joined_sync.state.itervalues()
                 )
                 for event in it:
                     if event.type == EventTypes.Member:
                         if event.membership == Membership.JOIN:
                             newly_joined_users.add(event.state_key)
+                        else:
+                            prev_content = event.unsigned.get("prev_content", {})
+                            prev_membership = prev_content.get("membership", None)
+                            if prev_membership == Membership.JOIN:
+                                newly_left_users.add(event.state_key)
+
+        newly_left_users -= newly_joined_users
+
+        defer.returnValue((
+            newly_joined_rooms,
+            newly_joined_users,
+            newly_left_rooms,
+            newly_left_users,
+        ))
+
+    @defer.inlineCallbacks
+    def _have_rooms_changed(self, sync_result_builder):
+        """Returns whether there may be any new events that should be sent down
+        the sync. Returns True if there are.
+        """
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+        now_token = sync_result_builder.now_token
 
-        defer.returnValue((newly_joined_rooms, newly_joined_users))
+        assert since_token
+
+        # Get a list of membership change events that have happened.
+        rooms_changed = yield self.store.get_membership_changes_for_user(
+            user_id, since_token.room_key, now_token.room_key
+        )
+
+        if rooms_changed:
+            defer.returnValue(True)
+
+        stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+        for room_id in sync_result_builder.joined_room_ids:
+            if self.store.has_room_changed_since(room_id, stream_id):
+                defer.returnValue(True)
+        defer.returnValue(False)
 
     @defer.inlineCallbacks
     def _get_rooms_changed(self, sync_result_builder, ignored_users):
@@ -836,14 +1037,6 @@ class SyncHandler(object):
 
         assert since_token
 
-        app_service = self.store.get_app_service_by_user_id(user_id)
-        if app_service:
-            rooms = yield self.store.get_app_service_rooms(app_service)
-            joined_room_ids = set(r.room_id for r in rooms)
-        else:
-            rooms = yield self.store.get_rooms_for_user(user_id)
-            joined_room_ids = set(r.room_id for r in rooms)
-
         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
@@ -854,16 +1047,29 @@ class SyncHandler(object):
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
         newly_joined_rooms = []
+        newly_left_rooms = []
         room_entries = []
         invited = []
-        for room_id, events in mem_change_events_by_room_id.items():
+        for room_id, events in mem_change_events_by_room_id.iteritems():
             non_joins = [e for e in events if e.membership != Membership.JOIN]
             has_join = len(non_joins) != len(events)
 
             # We want to figure out if we joined the room at some point since
             # the last sync (even if we have since left). This is to make sure
             # we do send down the room, and with full state, where necessary
-            if room_id in joined_room_ids or has_join:
+
+            old_state_ids = None
+            if room_id in sync_result_builder.joined_room_ids and non_joins:
+                # Always include if the user (re)joined the room, especially
+                # important so that device list changes are calculated correctly.
+                # If there are non join member events, but we are still in the room,
+                # then the user must have left and joined
+                newly_joined_rooms.append(room_id)
+
+                # User is in the room so we don't need to do the invite/leave checks
+                continue
+
+            if room_id in sync_result_builder.joined_room_ids or has_join:
                 old_state_ids = yield self.get_state_at(room_id, since_token)
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
@@ -874,12 +1080,33 @@ class SyncHandler(object):
                 if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
                     newly_joined_rooms.append(room_id)
 
-                if room_id in joined_room_ids:
-                    continue
+            # If user is in the room then we don't need to do the invite/leave checks
+            if room_id in sync_result_builder.joined_room_ids:
+                continue
 
             if not non_joins:
                 continue
 
+            # Check if we have left the room. This can either be because we were
+            # joined before *or* that we since joined and then left.
+            if events[-1].membership != Membership.JOIN:
+                if has_join:
+                    newly_left_rooms.append(room_id)
+                else:
+                    if not old_state_ids:
+                        old_state_ids = yield self.get_state_at(room_id, since_token)
+                        old_mem_ev_id = old_state_ids.get(
+                            (EventTypes.Member, user_id),
+                            None,
+                        )
+                        old_mem_ev = None
+                        if old_mem_ev_id:
+                            old_mem_ev = yield self.store.get_event(
+                                old_mem_ev_id, allow_none=True
+                            )
+                    if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
+                        newly_left_rooms.append(room_id)
+
             # Only bother if we're still currently invited
             should_invite = non_joins[-1].membership == Membership.INVITE
             if should_invite:
@@ -921,7 +1148,7 @@ class SyncHandler(object):
 
         # Get all events for rooms we're currently joined to.
         room_to_events = yield self.store.get_room_events_stream_for_rooms(
-            room_ids=joined_room_ids,
+            room_ids=sync_result_builder.joined_room_ids,
             from_key=since_token.room_key,
             to_key=now_token.room_key,
             limit=timeline_limit + 1,
@@ -929,7 +1156,7 @@ class SyncHandler(object):
 
         # We loop through all room ids, even if there are no new events, in case
         # there are non room events taht we need to notify about.
-        for room_id in joined_room_ids:
+        for room_id in sync_result_builder.joined_room_ids:
             room_entry = room_to_events.get(room_id, None)
 
             if room_entry:
@@ -957,7 +1184,7 @@ class SyncHandler(object):
                     upto_token=since_token,
                 ))
 
-        defer.returnValue((room_entries, invited, newly_joined_rooms))
+        defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
 
     @defer.inlineCallbacks
     def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1137,6 +1364,54 @@ class SyncHandler(object):
         else:
             raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
+    @defer.inlineCallbacks
+    def get_rooms_for_user_at(self, user_id, stream_ordering):
+        """Get set of joined rooms for a user at the given stream ordering.
+
+        The stream ordering *must* be recent, otherwise this may throw an
+        exception if older than a month. (This function is called with the
+        current token, which should be perfectly fine).
+
+        Args:
+            user_id (str)
+            stream_ordering (int)
+
+        ReturnValue:
+            Deferred[frozenset[str]]: Set of room_ids the user is in at given
+            stream_ordering.
+        """
+        joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
+            user_id,
+        )
+
+        joined_room_ids = set()
+
+        # We need to check that the stream ordering of the join for each room
+        # is before the stream_ordering asked for. This might not be the case
+        # if the user joins a room between us getting the current token and
+        # calling `get_rooms_for_user_with_stream_ordering`.
+        # If the membership's stream ordering is after the given stream
+        # ordering, we need to go and work out if the user was in the room
+        # before.
+        for room_id, membership_stream_ordering in joined_rooms:
+            if membership_stream_ordering <= stream_ordering:
+                joined_room_ids.add(room_id)
+                continue
+
+            logger.info("User joined room after current token: %s", room_id)
+
+            extrems = yield self.store.get_forward_extremeties_for_room(
+                room_id, stream_ordering,
+            )
+            users_in_room = yield self.state.get_current_user_in_room(
+                room_id, extrems,
+            )
+            if user_id in users_in_room:
+                joined_room_ids.add(room_id)
+
+        joined_room_ids = frozenset(joined_room_ids)
+        defer.returnValue(joined_room_ids)
+
 
 def _action_has_highlight(actions):
     for action in actions:
@@ -1186,7 +1461,8 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
 
 class SyncResultBuilder(object):
     "Used to help build up a new SyncResult for a user"
-    def __init__(self, sync_config, full_state, since_token, now_token):
+    def __init__(self, sync_config, full_state, since_token, now_token,
+                 joined_room_ids):
         """
         Args:
             sync_config(SyncConfig)
@@ -1198,6 +1474,7 @@ class SyncResultBuilder(object):
         self.full_state = full_state
         self.since_token = since_token
         self.now_token = now_token
+        self.joined_room_ids = joined_room_ids
 
         self.presence = []
         self.account_data = []
@@ -1205,6 +1482,8 @@ class SyncResultBuilder(object):
         self.invited = []
         self.archived = []
         self.device = []
+        self.groups = None
+        self.to_device = []
 
 
 class RoomSyncResultBuilder(object):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0eea7f8f9c..5d9736e88f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 from synapse.types import UserID, get_domain_from_id
@@ -24,7 +24,6 @@ from synapse.types import UserID, get_domain_from_id
 import logging
 
 from collections import namedtuple
-import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -57,7 +56,7 @@ class TypingHandler(object):
 
         self.federation = hs.get_federation_sender()
 
-        hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu)
+        hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
@@ -90,7 +89,7 @@ class TypingHandler(object):
             until = self._member_typing_until.get(member, None)
             if not until or until <= now:
                 logger.info("Timing out typing for: %s", member.user_id)
-                preserve_fn(self._stopped_typing)(member)
+                self._stopped_typing(member)
                 continue
 
             # Check if we need to resend a keep alive over federation for this
@@ -98,7 +97,8 @@ class TypingHandler(object):
             if self.hs.is_mine_id(member.user_id):
                 last_fed_poke = self._member_last_federation_poke.get(member, None)
                 if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
-                    preserve_fn(self._push_remote)(
+                    run_in_background(
+                        self._push_remote,
                         member=member,
                         typing=True
                     )
@@ -148,7 +148,7 @@ class TypingHandler(object):
             # No point sending another notification
             defer.returnValue(None)
 
-        yield self._push_update(
+        self._push_update(
             member=member,
             typing=True,
         )
@@ -172,7 +172,7 @@ class TypingHandler(object):
 
         member = RoomMember(room_id=room_id, user_id=target_user_id)
 
-        yield self._stopped_typing(member)
+        self._stopped_typing(member)
 
     @defer.inlineCallbacks
     def user_left_room(self, user, room_id):
@@ -181,7 +181,6 @@ class TypingHandler(object):
             member = RoomMember(room_id=room_id, user_id=user_id)
             yield self._stopped_typing(member)
 
-    @defer.inlineCallbacks
     def _stopped_typing(self, member):
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
@@ -190,16 +189,15 @@ class TypingHandler(object):
         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)
 
-        yield self._push_update(
+        self._push_update(
             member=member,
             typing=False,
         )
 
-    @defer.inlineCallbacks
     def _push_update(self, member, typing):
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
-            yield self._push_remote(member, typing)
+            run_in_background(self._push_remote, member, typing)
 
         self._push_update_local(
             member=member,
@@ -208,28 +206,31 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _push_remote(self, member, typing):
-        users = yield self.state.get_current_user_in_room(member.room_id)
-        self._member_last_federation_poke[member] = self.clock.time_msec()
+        try:
+            users = yield self.state.get_current_user_in_room(member.room_id)
+            self._member_last_federation_poke[member] = self.clock.time_msec()
 
-        now = self.clock.time_msec()
-        self.wheel_timer.insert(
-            now=now,
-            obj=member,
-            then=now + FEDERATION_PING_INTERVAL,
-        )
+            now = self.clock.time_msec()
+            self.wheel_timer.insert(
+                now=now,
+                obj=member,
+                then=now + FEDERATION_PING_INTERVAL,
+            )
 
-        for domain in set(get_domain_from_id(u) for u in users):
-            if domain != self.server_name:
-                self.federation.send_edu(
-                    destination=domain,
-                    edu_type="m.typing",
-                    content={
-                        "room_id": member.room_id,
-                        "user_id": member.user_id,
-                        "typing": typing,
-                    },
-                    key=member,
-                )
+            for domain in set(get_domain_from_id(u) for u in users):
+                if domain != self.server_name:
+                    self.federation.send_edu(
+                        destination=domain,
+                        edu_type="m.typing",
+                        content={
+                            "room_id": member.room_id,
+                            "user_id": member.user_id,
+                            "typing": typing,
+                        },
+                        key=member,
+                    )
+        except Exception:
+            logger.exception("Error pushing typing notif to remotes")
 
     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
@@ -288,11 +289,13 @@ class TypingHandler(object):
         for room_id, serial in self._room_serials.items():
             if last_id < serial and serial <= current_id:
                 typing = self._room_typing[room_id]
-                typing_bytes = json.dumps(list(typing), ensure_ascii=False)
-                rows.append((serial, room_id, typing_bytes))
+                rows.append((serial, room_id, list(typing)))
         rows.sort()
         return rows
 
+    def get_current_token(self):
+        return self._latest_room_serial
+
 
 class TypingNotificationEventSource(object):
     def __init__(self, hs):
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
new file mode 100644
index 0000000000..714f0195c8
--- /dev/null
+++ b/synapse/handlers/user_directory.py
@@ -0,0 +1,681 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.storage.roommember import ProfileInfo
+from synapse.util.metrics import Measure
+from synapse.util.async import sleep
+from synapse.types import get_localpart_from_id
+
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryHandler(object):
+    """Handles querying of and keeping updated the user_directory.
+
+    N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
+
+    The user directory is filled with users who this server can see are joined to a
+    world_readable or publically joinable room. We keep a database table up to date
+    by streaming changes of the current state and recalculating whether users should
+    be in the directory or not when necessary.
+
+    For each user in the directory we also store a room_id which is public and that the
+    user is joined to. This allows us to ignore history_visibility and join_rules changes
+    for that user in all other public rooms, as we know they'll still be in at least
+    one public room.
+    """
+
+    INITIAL_ROOM_SLEEP_MS = 50
+    INITIAL_ROOM_SLEEP_COUNT = 100
+    INITIAL_ROOM_BATCH_SIZE = 100
+    INITIAL_USER_SLEEP_MS = 10
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
+        self.server_name = hs.hostname
+        self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
+        self.update_user_directory = hs.config.update_user_directory
+        self.search_all_users = hs.config.user_directory_search_all_users
+
+        # When start up for the first time we need to populate the user_directory.
+        # This is a set of user_id's we've inserted already
+        self.initially_handled_users = set()
+        self.initially_handled_users_in_public = set()
+
+        self.initially_handled_users_share = set()
+        self.initially_handled_users_share_private_room = set()
+
+        # The current position in the current_state_delta stream
+        self.pos = None
+
+        # Guard to ensure we only process deltas one at a time
+        self._is_processing = False
+
+        if self.update_user_directory:
+            self.notifier.add_replication_callback(self.notify_new_event)
+
+            # We kick this off so that we don't have to wait for a change before
+            # we start populating the user directory
+            self.clock.call_later(0, self.notify_new_event)
+
+    def search_users(self, user_id, search_term, limit):
+        """Searches for users in directory
+
+        Returns:
+            dict of the form::
+
+                {
+                    "limited": <bool>,  # whether there were more results or not
+                    "results": [  # Ordered by best match first
+                        {
+                            "user_id": <user_id>,
+                            "display_name": <display_name>,
+                            "avatar_url": <avatar_url>
+                        }
+                    ]
+                }
+        """
+        return self.store.search_user_dir(user_id, search_term, limit)
+
+    @defer.inlineCallbacks
+    def notify_new_event(self):
+        """Called when there may be more deltas to process
+        """
+        if not self.update_user_directory:
+            return
+
+        if self._is_processing:
+            return
+
+        self._is_processing = True
+        try:
+            yield self._unsafe_process()
+        finally:
+            self._is_processing = False
+
+    @defer.inlineCallbacks
+    def handle_local_profile_change(self, user_id, profile):
+        """Called to update index of our local user profiles when they change
+        irrespective of any rooms the user may be in.
+        """
+        yield self.store.update_profile_in_user_dir(
+            user_id, profile.display_name, profile.avatar_url, None,
+        )
+
+    @defer.inlineCallbacks
+    def _unsafe_process(self):
+        # If self.pos is None then means we haven't fetched it from DB
+        if self.pos is None:
+            self.pos = yield self.store.get_user_directory_stream_pos()
+
+        # If still None then we need to do the initial fill of directory
+        if self.pos is None:
+            yield self._do_initial_spam()
+            self.pos = yield self.store.get_user_directory_stream_pos()
+
+        # Loop round handling deltas until we're up to date
+        while True:
+            with Measure(self.clock, "user_dir_delta"):
+                deltas = yield self.store.get_current_state_deltas(self.pos)
+                if not deltas:
+                    return
+
+                logger.info("Handling %d state deltas", len(deltas))
+                yield self._handle_deltas(deltas)
+
+                self.pos = deltas[-1]["stream_id"]
+                yield self.store.update_user_directory_stream_pos(self.pos)
+
+    @defer.inlineCallbacks
+    def _do_initial_spam(self):
+        """Populates the user_directory from the current state of the DB, used
+        when synapse first starts with user_directory support
+        """
+        new_pos = yield self.store.get_max_stream_id_in_current_state_deltas()
+
+        # Delete any existing entries just in case there are any
+        yield self.store.delete_all_from_user_dir()
+
+        # We process by going through each existing room at a time.
+        room_ids = yield self.store.get_all_rooms()
+
+        logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
+        num_processed_rooms = 0
+
+        for room_id in room_ids:
+            logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
+            yield self._handle_initial_room(room_id)
+            num_processed_rooms += 1
+            yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+
+        logger.info("Processed all rooms.")
+
+        if self.search_all_users:
+            num_processed_users = 0
+            user_ids = yield self.store.get_all_local_users()
+            logger.info("Doing initial update of user directory. %d users", len(user_ids))
+            for user_id in user_ids:
+                # We add profiles for all users even if they don't match the
+                # include pattern, just in case we want to change it in future
+                logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
+                yield self._handle_local_user(user_id)
+                num_processed_users += 1
+                yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
+
+            logger.info("Processed all users")
+
+        self.initially_handled_users = None
+        self.initially_handled_users_in_public = None
+        self.initially_handled_users_share = None
+        self.initially_handled_users_share_private_room = None
+
+        yield self.store.update_user_directory_stream_pos(new_pos)
+
+    @defer.inlineCallbacks
+    def _handle_initial_room(self, room_id):
+        """Called when we initially fill out user_directory one room at a time
+        """
+        is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
+        if not is_in_room:
+            return
+
+        is_public = yield self.store.is_room_world_readable_or_publicly_joinable(room_id)
+
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+        user_ids = set(users_with_profile)
+        unhandled_users = user_ids - self.initially_handled_users
+
+        yield self.store.add_profiles_to_user_dir(
+            room_id, {
+                user_id: users_with_profile[user_id] for user_id in unhandled_users
+            }
+        )
+
+        self.initially_handled_users |= unhandled_users
+
+        if is_public:
+            yield self.store.add_users_to_public_room(
+                room_id,
+                user_ids=user_ids - self.initially_handled_users_in_public
+            )
+            self.initially_handled_users_in_public |= user_ids
+
+        # We now go and figure out the new users who share rooms with user entries
+        # We sleep aggressively here as otherwise it can starve resources.
+        # We also batch up inserts/updates, but try to avoid too many at once.
+        to_insert = set()
+        to_update = set()
+        count = 0
+        for user_id in user_ids:
+            if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+
+            if not self.is_mine_id(user_id):
+                count += 1
+                continue
+
+            if self.store.get_if_app_services_interested_in_user(user_id):
+                count += 1
+                continue
+
+            for other_user_id in user_ids:
+                if user_id == other_user_id:
+                    continue
+
+                if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                    yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
+                count += 1
+
+                user_set = (user_id, other_user_id)
+
+                if user_set in self.initially_handled_users_share_private_room:
+                    continue
+
+                if user_set in self.initially_handled_users_share:
+                    if is_public:
+                        continue
+                    to_update.add(user_set)
+                else:
+                    to_insert.add(user_set)
+
+                if is_public:
+                    self.initially_handled_users_share.add(user_set)
+                else:
+                    self.initially_handled_users_share_private_room.add(user_set)
+
+                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+                    yield self.store.add_users_who_share_room(
+                        room_id, not is_public, to_insert,
+                    )
+                    to_insert.clear()
+
+                if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
+                    yield self.store.update_users_who_share_room(
+                        room_id, not is_public, to_update,
+                    )
+                    to_update.clear()
+
+        if to_insert:
+            yield self.store.add_users_who_share_room(
+                room_id, not is_public, to_insert,
+            )
+            to_insert.clear()
+
+        if to_update:
+            yield self.store.update_users_who_share_room(
+                room_id, not is_public, to_update,
+            )
+            to_update.clear()
+
+    @defer.inlineCallbacks
+    def _handle_deltas(self, deltas):
+        """Called with the state deltas to process
+        """
+        for delta in deltas:
+            typ = delta["type"]
+            state_key = delta["state_key"]
+            room_id = delta["room_id"]
+            event_id = delta["event_id"]
+            prev_event_id = delta["prev_event_id"]
+
+            logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+            # For join rule and visibility changes we need to check if the room
+            # may have become public or not and add/remove the users in said room
+            if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules):
+                yield self._handle_room_publicity_change(
+                    room_id, prev_event_id, event_id, typ,
+                )
+            elif typ == EventTypes.Member:
+                change = yield self._get_key_change(
+                    prev_event_id, event_id,
+                    key_name="membership",
+                    public_value=Membership.JOIN,
+                )
+
+                if change is None:
+                    # Handle any profile changes
+                    yield self._handle_profile_change(
+                        state_key, room_id, prev_event_id, event_id,
+                    )
+                    continue
+
+                if not change:
+                    # Need to check if the server left the room entirely, if so
+                    # we might need to remove all the users in that room
+                    is_in_room = yield self.store.is_host_joined(
+                        room_id, self.server_name,
+                    )
+                    if not is_in_room:
+                        logger.info("Server left room: %r", room_id)
+                        # Fetch all the users that we marked as being in user
+                        # directory due to being in the room and then check if
+                        # need to remove those users or not
+                        user_ids = yield self.store.get_users_in_dir_due_to_room(room_id)
+                        for user_id in user_ids:
+                            yield self._handle_remove_user(room_id, user_id)
+                        return
+                    else:
+                        logger.debug("Server is still in room: %r", room_id)
+
+                if change:  # The user joined
+                    event = yield self.store.get_event(event_id, allow_none=True)
+                    profile = ProfileInfo(
+                        avatar_url=event.content.get("avatar_url"),
+                        display_name=event.content.get("displayname"),
+                    )
+
+                    yield self._handle_new_user(room_id, state_key, profile)
+                else:  # The user left
+                    yield self._handle_remove_user(room_id, state_key)
+            else:
+                logger.debug("Ignoring irrelevant type: %r", typ)
+
+    @defer.inlineCallbacks
+    def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ):
+        """Handle a room having potentially changed from/to world_readable/publically
+        joinable.
+
+        Args:
+            room_id (str)
+            prev_event_id (str|None): The previous event before the state change
+            event_id (str|None): The new event after the state change
+            typ (str): Type of the event
+        """
+        logger.debug("Handling change for %s: %s", typ, room_id)
+
+        if typ == EventTypes.RoomHistoryVisibility:
+            change = yield self._get_key_change(
+                prev_event_id, event_id,
+                key_name="history_visibility",
+                public_value="world_readable",
+            )
+        elif typ == EventTypes.JoinRules:
+            change = yield self._get_key_change(
+                prev_event_id, event_id,
+                key_name="join_rule",
+                public_value=JoinRules.PUBLIC,
+            )
+        else:
+            raise Exception("Invalid event type")
+        # If change is None, no change. True => become world_readable/public,
+        # False => was world_readable/public
+        if change is None:
+            logger.debug("No change")
+            return
+
+        # There's been a change to or from being world readable.
+
+        is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+            room_id
+        )
+
+        logger.debug("Change: %r, is_public: %r", change, is_public)
+
+        if change and not is_public:
+            # If we became world readable but room isn't currently public then
+            # we ignore the change
+            return
+        elif not change and is_public:
+            # If we stopped being world readable but are still public,
+            # ignore the change
+            return
+
+        if change:
+            users_with_profile = yield self.state.get_current_user_in_room(room_id)
+            for user_id, profile in users_with_profile.iteritems():
+                yield self._handle_new_user(room_id, user_id, profile)
+        else:
+            users = yield self.store.get_users_in_public_due_to_room(room_id)
+            for user_id in users:
+                yield self._handle_remove_user(room_id, user_id)
+
+    @defer.inlineCallbacks
+    def _handle_local_user(self, user_id):
+        """Adds a new local roomless user into the user_directory_search table.
+        Used to populate up the user index when we have an
+        user_directory_search_all_users specified.
+        """
+        logger.debug("Adding new local user to dir, %r", user_id)
+
+        profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
+
+        row = yield self.store.get_user_in_directory(user_id)
+        if not row:
+            yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+
+    @defer.inlineCallbacks
+    def _handle_new_user(self, room_id, user_id, profile):
+        """Called when we might need to add user to directory
+
+        Args:
+            room_id (str): room_id that user joined or started being public
+            user_id (str)
+        """
+        logger.debug("Adding new user to dir, %r", user_id)
+
+        row = yield self.store.get_user_in_directory(user_id)
+        if not row:
+            yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
+
+        is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+            room_id
+        )
+
+        if is_public:
+            row = yield self.store.get_user_in_public_room(user_id)
+            if not row:
+                yield self.store.add_users_to_public_room(room_id, [user_id])
+        else:
+            logger.debug("Not adding new user to public dir, %r", user_id)
+
+        # Now we update users who share rooms with users. We do this by getting
+        # all the current users in the room and seeing which aren't already
+        # marked in the database as sharing with `user_id`
+
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+        to_insert = set()
+        to_update = set()
+
+        is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+
+        # First, if they're our user then we need to update for every user
+        if self.is_mine_id(user_id) and not is_appservice:
+            # Returns a map of other_user_id -> shared_private. We only need
+            # to update mappings if for users that either don't share a room
+            # already (aren't in the map) or, if the room is private, those that
+            # only share a public room.
+            user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
+                user_id
+            )
+
+            for other_user_id in users_with_profile:
+                if user_id == other_user_id:
+                    continue
+
+                shared_is_private = user_ids_shared.get(other_user_id)
+                if shared_is_private is True:
+                    # We've already marked in the database they share a private room
+                    continue
+                elif shared_is_private is False:
+                    # They already share a public room, so only update if this is
+                    # a private room
+                    if not is_public:
+                        to_update.add((user_id, other_user_id))
+                elif shared_is_private is None:
+                    # This is the first time they both share a room
+                    to_insert.add((user_id, other_user_id))
+
+        # Next we need to update for every local user in the room
+        for other_user_id in users_with_profile:
+            if user_id == other_user_id:
+                continue
+
+            is_appservice = self.store.get_if_app_services_interested_in_user(
+                other_user_id
+            )
+            if self.is_mine_id(other_user_id) and not is_appservice:
+                shared_is_private = yield self.store.get_if_users_share_a_room(
+                    other_user_id, user_id,
+                )
+                if shared_is_private is True:
+                    # We've already marked in the database they share a private room
+                    continue
+                elif shared_is_private is False:
+                    # They already share a public room, so only update if this is
+                    # a private room
+                    if not is_public:
+                        to_update.add((other_user_id, user_id))
+                elif shared_is_private is None:
+                    # This is the first time they both share a room
+                    to_insert.add((other_user_id, user_id))
+
+        if to_insert:
+            yield self.store.add_users_who_share_room(
+                room_id, not is_public, to_insert,
+            )
+
+        if to_update:
+            yield self.store.update_users_who_share_room(
+                room_id, not is_public, to_update,
+            )
+
+    @defer.inlineCallbacks
+    def _handle_remove_user(self, room_id, user_id):
+        """Called when we might need to remove user to directory
+
+        Args:
+            room_id (str): room_id that user left or stopped being public that
+            user_id (str)
+        """
+        logger.debug("Maybe removing user %r", user_id)
+
+        row = yield self.store.get_user_in_directory(user_id)
+        update_user_dir = row and row["room_id"] == room_id
+
+        row = yield self.store.get_user_in_public_room(user_id)
+        update_user_in_public = row and row["room_id"] == room_id
+
+        if (update_user_in_public or update_user_dir):
+            # XXX: Make this faster?
+            rooms = yield self.store.get_rooms_for_user(user_id)
+            for j_room_id in rooms:
+                if (not update_user_in_public and not update_user_dir):
+                    break
+
+                is_in_room = yield self.store.is_host_joined(
+                    j_room_id, self.server_name,
+                )
+
+                if not is_in_room:
+                    continue
+
+                if update_user_dir:
+                    update_user_dir = False
+                    yield self.store.update_user_in_user_dir(user_id, j_room_id)
+
+                is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+                    j_room_id
+                )
+
+                if update_user_in_public and is_public:
+                    yield self.store.update_user_in_public_user_list(user_id, j_room_id)
+                    update_user_in_public = False
+
+        if update_user_dir:
+            yield self.store.remove_from_user_dir(user_id)
+        elif update_user_in_public:
+            yield self.store.remove_from_user_in_public_room(user_id)
+
+        # Now handle users_who_share_rooms.
+
+        # Get a list of user tuples that were in the DB due to this room and
+        # users (this includes tuples where the other user matches `user_id`)
+        user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
+            user_id, room_id,
+        )
+
+        for user_id, other_user_id in user_tuples:
+            # For each user tuple get a list of rooms that they still share,
+            # trying to find a private room, and update the entry in the DB
+            rooms = yield self.store.get_rooms_in_common_for_users(user_id, other_user_id)
+
+            # If they dont share a room anymore, remove the mapping
+            if not rooms:
+                yield self.store.remove_user_who_share_room(
+                    user_id, other_user_id,
+                )
+                continue
+
+            found_public_share = None
+            for j_room_id in rooms:
+                is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
+                    j_room_id
+                )
+
+                if is_public:
+                    found_public_share = j_room_id
+                else:
+                    found_public_share = None
+                    yield self.store.update_users_who_share_room(
+                        room_id, not is_public, [(user_id, other_user_id)],
+                    )
+                    break
+
+            if found_public_share:
+                yield self.store.update_users_who_share_room(
+                    room_id, not is_public, [(user_id, other_user_id)],
+                )
+
+    @defer.inlineCallbacks
+    def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+        """Check member event changes for any profile changes and update the
+        database if there are.
+        """
+        if not prev_event_id or not event_id:
+            return
+
+        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        event = yield self.store.get_event(event_id, allow_none=True)
+
+        if not prev_event or not event:
+            return
+
+        if event.membership != Membership.JOIN:
+            return
+
+        prev_name = prev_event.content.get("displayname")
+        new_name = event.content.get("displayname")
+
+        prev_avatar = prev_event.content.get("avatar_url")
+        new_avatar = event.content.get("avatar_url")
+
+        if prev_name != new_name or prev_avatar != new_avatar:
+            yield self.store.update_profile_in_user_dir(
+                user_id, new_name, new_avatar, room_id,
+            )
+
+    @defer.inlineCallbacks
+    def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+        """Given two events check if the `key_name` field in content changed
+        from not matching `public_value` to doing so.
+
+        For example, check if `history_visibility` (`key_name`) changed from
+        `shared` to `world_readable` (`public_value`).
+
+        Returns:
+            None if the field in the events either both match `public_value`
+            or if neither do, i.e. there has been no change.
+            True if it didnt match `public_value` but now does
+            False if it did match `public_value` but now doesn't
+        """
+        prev_event = None
+        event = None
+        if prev_event_id:
+            prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+
+        if event_id:
+            event = yield self.store.get_event(event_id, allow_none=True)
+
+        if not event and not prev_event:
+            logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
+            defer.returnValue(None)
+
+        prev_value = None
+        value = None
+
+        if prev_event:
+            prev_value = prev_event.content.get(key_name)
+
+        if event:
+            value = event.content.get(key_name)
+
+        logger.debug("prev_value: %r -> value: %r", prev_value, value)
+
+        if value == public_value and prev_value != public_value:
+            defer.returnValue(True)
+        elif value != public_value and prev_value == public_value:
+            defer.returnValue(False)
+        else:
+            defer.returnValue(None)
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..054372e179 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,3 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+    """Exception representing timeout of an outbound request"""
+    def __init__(self):
+        super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+    """Turns CancelledErrors into RequestTimedOutErrors.
+
+    For use with async.add_timeout_to_deferred
+    """
+    if isinstance(value, failure.Failure):
+        value.trap(CancelledError)
+        raise RequestTimedOutError()
+    return value
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
new file mode 100644
index 0000000000..343e932cb1
--- /dev/null
+++ b/synapse/http/additional_resource.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import wrap_request_handler
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+
+class AdditionalResource(Resource):
+    """Resource wrapper for additional_resources
+
+    If the user has configured additional_resources, we need to wrap the
+    handler class with a Resource so that we can map it into the resource tree.
+
+    This class is also where we wrap the request handler with logging, metrics,
+    and exception handling.
+    """
+    def __init__(self, hs, handler):
+        """Initialise AdditionalResource
+
+        The ``handler`` should return a deferred which completes when it has
+        done handling the request. It should write a response with
+        ``request.write()``, and call ``request.finish()``.
+
+        Args:
+            hs (synapse.server.HomeServer): homeserver
+            handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
+                function to be called to handle the request.
+        """
+        Resource.__init__(self)
+        self._handler = handler
+
+        # these are required by the request_handler wrapper
+        self.version_string = hs.version_string
+        self.clock = hs.get_clock()
+
+    def render(self, request):
+        self._async_render(request)
+        return NOT_DONE_YET
+
+    @wrap_request_handler
+    def _async_render(self, request):
+        return self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2f770f5d..70a19d9b74 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,9 +17,12 @@ from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
 
 from synapse.api.errors import (
-    CodeMessageException, SynapseError, Codes,
+    CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.logcontext import make_deferred_yieldable
 import synapse.metrics
 from synapse.http.endpoint import SpiderEndpoint
 
@@ -29,13 +33,14 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import (
     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
     readBody, PartialDownloadError,
+    HTTPConnectionPool,
 )
 from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone
 
-from StringIO import StringIO
+from six import StringIO
 
 import simplejson as json
 import logging
@@ -63,92 +68,139 @@ class SimpleHttpClient(object):
     """
     def __init__(self, hs):
         self.hs = hs
+
+        pool = HTTPConnectionPool(reactor)
+
+        # the pusher makes lots of concurrent SSL connections to sygnal, and
+        # tends to do so in batches, so we need to allow the pool to keep lots
+        # of idle connections around.
+        pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
+        pool.cachedConnectionTimeout = 2 * 60
+
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
         self.agent = Agent(
             reactor,
             connectTimeout=15,
-            contextFactory=hs.get_http_client_context_factory()
+            contextFactory=hs.get_http_client_context_factory(),
+            pool=pool,
         )
         self.user_agent = hs.version_string
+        self.clock = hs.get_clock()
         if hs.config.user_agent_suffix:
             self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
 
+    @defer.inlineCallbacks
     def request(self, method, uri, *args, **kwargs):
         # A small wrapper around self.agent.request() so we can easily attach
         # counters to it
         outgoing_requests_counter.inc(method)
-        d = preserve_context_over_fn(
-            self.agent.request,
-            method, uri, *args, **kwargs
-        )
 
         logger.info("Sending request %s %s", method, uri)
 
-        def _cb(response):
+        try:
+            request_deferred = self.agent.request(
+                method, uri, *args, **kwargs
+            )
+            add_timeout_to_deferred(
+                request_deferred,
+                60, cancelled_to_request_timed_out_error,
+            )
+            response = yield make_deferred_yieldable(request_deferred)
+
             incoming_responses_counter.inc(method, response.code)
             logger.info(
                 "Received response to  %s %s: %s",
                 method, uri, response.code
             )
-            return response
-
-        def _eb(failure):
+            defer.returnValue(response)
+        except Exception as e:
             incoming_responses_counter.inc(method, "ERR")
             logger.info(
                 "Error sending request to  %s %s: %s %s",
-                method, uri, failure.type, failure.getErrorMessage()
+                method, uri, type(e).__name__, e.message
             )
-            return failure
+            raise e
 
-        d.addCallbacks(_cb, _eb)
+    @defer.inlineCallbacks
+    def post_urlencoded_get_json(self, uri, args={}, headers=None):
+        """
+        Args:
+            uri (str):
+            args (dict[str, str|List[str]]): query params
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
 
-        return d
+        Returns:
+            Deferred[object]: parsed json
+        """
 
-    @defer.inlineCallbacks
-    def post_urlencoded_get_json(self, uri, args={}):
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
         query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
+        actual_headers = {
+            b"Content-Type": [b"application/x-www-form-urlencoded"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/x-www-form-urlencoded"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(query_bytes))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def post_json_get_json(self, uri, post_json):
+    def post_json_get_json(self, uri, post_json, headers=None):
+        """
+
+        Args:
+            uri (str):
+            post_json (object):
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
+
+        Returns:
+            Deferred[object]: parsed json
+        """
         json_str = encode_canonical_json(post_json)
 
         logger.debug("HTTP POST %s -> %s", json_str, uri)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "POST",
             uri.encode("ascii"),
-            headers=Headers({
-                b"Content-Type": [b"application/json"],
-                b"User-Agent": [self.user_agent],
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
+
+        if 200 <= response.code < 300:
+            defer.returnValue(json.loads(body))
+        else:
+            raise self._exceptionFromFailedRequest(response, body)
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
-    def get_json(self, uri, args={}):
+    def get_json(self, uri, args={}, headers=None):
         """ Gets some json from the given URI.
 
         Args:
@@ -157,6 +209,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -164,11 +218,14 @@ class SimpleHttpClient(object):
             On a non-2xx HTTP response. The response body will be used as the
             error message.
         """
-        body = yield self.get_raw(uri, args)
-        defer.returnValue(json.loads(body))
+        try:
+            body = yield self.get_raw(uri, args, headers=headers)
+            defer.returnValue(json.loads(body))
+        except CodeMessageException as e:
+            raise self._exceptionFromFailedRequest(e.code, e.msg)
 
     @defer.inlineCallbacks
-    def put_json(self, uri, json_body, args={}):
+    def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
 
         Args:
@@ -178,6 +235,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body as JSON.
@@ -190,17 +249,21 @@ class SimpleHttpClient(object):
 
         json_str = encode_canonical_json(json_body)
 
+        actual_headers = {
+            b"Content-Type": [b"application/json"],
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "PUT",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-                "Content-Type": ["application/json"]
-            }),
+            headers=Headers(actual_headers),
             bodyProducer=FileBodyProducer(StringIO(json_str))
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(json.loads(body))
@@ -211,7 +274,7 @@ class SimpleHttpClient(object):
             raise CodeMessageException(response.code, body)
 
     @defer.inlineCallbacks
-    def get_raw(self, uri, args={}):
+    def get_raw(self, uri, args={}, headers=None):
         """ Gets raw text from the given URI.
 
         Args:
@@ -220,6 +283,8 @@ class SimpleHttpClient(object):
                 None.
                 **Note**: The value of each key is assumed to be an iterable
                 and *not* a string.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             Deferred: Succeeds when we get *any* 2xx HTTP response, with the
             HTTP body at text.
@@ -231,46 +296,65 @@ class SimpleHttpClient(object):
             query_bytes = urllib.urlencode(args, True)
             uri = "%s?%s" % (uri, query_bytes)
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             uri.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        body = yield preserve_context_over_fn(readBody, response)
+        body = yield make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
             defer.returnValue(body)
         else:
             raise CodeMessageException(response.code, body)
 
+    def _exceptionFromFailedRequest(self, response, body):
+        try:
+            jsonBody = json.loads(body)
+            errcode = jsonBody['errcode']
+            error = jsonBody['error']
+            return MatrixCodeMessageException(response.code, error, errcode)
+        except (ValueError, KeyError):
+            return CodeMessageException(response.code, body)
+
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.
 
     @defer.inlineCallbacks
-    def get_file(self, url, output_stream, max_size=None):
+    def get_file(self, url, output_stream, max_size=None, headers=None):
         """GETs a file from a given URL
         Args:
             url (str): The URL to GET
             output_stream (file): File to write the response body to.
+            headers (dict[str, List[str]]|None): If not None, a map from
+               header name to a list of values for that header
         Returns:
             A (int,dict,string,int) tuple of the file length, dict of the response
             headers, absolute URI of the response and HTTP response code.
         """
 
+        actual_headers = {
+            b"User-Agent": [self.user_agent],
+        }
+        if headers:
+            actual_headers.update(headers)
+
         response = yield self.request(
             "GET",
             url.encode("ascii"),
-            headers=Headers({
-                b"User-Agent": [self.user_agent],
-            })
+            headers=Headers(actual_headers),
         )
 
-        headers = dict(response.headers.getAllRawHeaders())
+        resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if 'Content-Length' in headers and headers['Content-Length'] > max_size:
+        if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
             logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
             raise SynapseError(
                 502,
@@ -291,10 +375,9 @@ class SimpleHttpClient(object):
         # straight back in again
 
         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
+            length = yield make_deferred_yieldable(_readBodyToFile(
+                response, output_stream, max_size,
+            ))
         except Exception as e:
             logger.exception("Failed to download body")
             raise SynapseError(
@@ -303,7 +386,9 @@ class SimpleHttpClient(object):
                 Codes.UNKNOWN,
             )
 
-        defer.returnValue((length, headers, response.request.absoluteURI, response.code))
+        defer.returnValue(
+            (length, resp_headers, response.request.absoluteURI, response.code),
+        )
 
 
 # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
@@ -371,7 +456,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
         )
 
         try:
-            body = yield preserve_context_over_fn(readBody, response)
+            body = yield make_deferred_yieldable(readBody(response))
             defer.returnValue(body)
         except PartialDownloadError as e:
             # twisted dislikes google's response, no content length.
@@ -422,7 +507,7 @@ class SpiderHttpClient(SimpleHttpClient):
                     reactor,
                     SpiderEndpointFactory(hs)
                 )
-            ), [('gzip', GzipDecoder)]
+            ), [(b'gzip', GzipDecoder)]
         )
         # We could look like Chrome:
         # self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 564ae4c10d..87a482650d 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet import defer, reactor
 from twisted.internet.error import ConnectError
@@ -30,7 +29,10 @@ logger = logging.getLogger(__name__)
 
 SERVER_CACHE = {}
 
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is the hostname acquired from the SRV record. Except when there's
+# no SRV record, in which case it is the original hostname.
 _Server = collections.namedtuple(
     "_Server", "priority weight host port expires"
 )
@@ -224,9 +226,10 @@ class SRVClientEndpoint(object):
                 return self.default_server
             else:
                 raise ConnectError(
-                    "Not server available for %s" % self.service_name
+                    "No server available for %s" % self.service_name
                 )
 
+        # look for all servers with the same priority
         min_priority = self.servers[0].priority
         weight_indexes = list(
             (index, server.weight + 1)
@@ -236,11 +239,22 @@ class SRVClientEndpoint(object):
 
         total_weight = sum(weight for index, weight in weight_indexes)
         target_weight = random.randint(0, total_weight)
-
         for index, weight in weight_indexes:
             target_weight -= weight
             if target_weight <= 0:
                 server = self.servers[index]
+                # XXX: this looks totally dubious:
+                #
+                # (a) we never reuse a server until we have been through
+                #     all of the servers at the same priority, so if the
+                #     weights are A: 100, B:1, we always do ABABAB instead of
+                #     AAAA...AAAB (approximately).
+                #
+                # (b) After using all the servers at the lowest priority,
+                #     we move onto the next priority. We should only use the
+                #     second priority if servers at the top priority are
+                #     unreachable.
+                #
                 del self.servers[index]
                 self.used_servers.append(server)
                 return server
@@ -277,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
         if (len(answers) == 1
                 and answers[0].type == dns.SRV
                 and answers[0].payload
-                and answers[0].payload.target == dns.Name('.')):
+                and answers[0].payload.target == dns.Name(b'.')):
             raise ConnectError("Service %s unavailable" % service_name)
 
         for answer in answers:
@@ -285,26 +299,14 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
                 continue
 
             payload = answer.payload
-            host = str(payload.target)
-            srv_ttl = answer.ttl
-
-            try:
-                answers, _, _ = yield dns_client.lookupAddress(host)
-            except DNSNameError:
-                continue
 
-            for answer in answers:
-                if answer.type == dns.A and answer.payload:
-                    ip = answer.payload.dottedQuad()
-                    host_ttl = min(srv_ttl, answer.ttl)
-
-                    servers.append(_Server(
-                        host=ip,
-                        port=int(payload.port),
-                        priority=int(payload.priority),
-                        weight=int(payload.weight),
-                        expires=int(clock.time()) + host_ttl,
-                    ))
+            servers.append(_Server(
+                host=str(payload.target),
+                port=int(payload.port),
+                priority=int(payload.priority),
+                weight=int(payload.weight),
+                expires=int(clock.time()) + answer.ttl,
+            ))
 
         servers.sort()
         cache[service_name] = list(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 78b92cef36..4b2b85464d 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,23 +13,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from twisted.internet import defer, reactor, protocol
 from twisted.internet.error import DNSLookupError
 from twisted.web.client import readBody, HTTPConnectionPool, Agent
 from twisted.web.http_headers import Headers
 from twisted.web._newclient import ResponseDone
 
+from synapse.http import cancelled_to_request_timed_out_error
 from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils
 
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import (
-    SynapseError, Codes, HttpResponseException,
+    SynapseError, Codes, HttpResponseException, FederationDeniedError,
 )
 
 from signedjson.sign import sign_json
@@ -39,8 +41,7 @@ import logging
 import random
 import sys
 import urllib
-import urlparse
-
+from six.moves.urllib import parse as urlparse
 
 logger = logging.getLogger(__name__)
 outbound_logger = logging.getLogger("synapse.http.outbound")
@@ -94,6 +95,7 @@ class MatrixFederationHttpClient(object):
             reactor, MatrixFederationEndpointFactory(hs), pool=pool
         )
         self.clock = hs.get_clock()
+        self._store = hs.get_datastore()
         self.version_string = hs.version_string
         self._next_id = 1
 
@@ -103,123 +105,161 @@ class MatrixFederationHttpClient(object):
         )
 
     @defer.inlineCallbacks
-    def _create_request(self, destination, method, path_bytes,
-                        body_callback, headers_dict={}, param_bytes=b"",
-                        query_bytes=b"", retry_on_dns_fail=True,
-                        timeout=None, long_retries=False):
-        """ Creates and sends a request to the given url
-        """
-        headers_dict[b"User-Agent"] = [self.version_string]
-        headers_dict[b"Host"] = [destination]
+    def _request(self, destination, method, path,
+                 body_callback, headers_dict={}, param_bytes=b"",
+                 query_bytes=b"", retry_on_dns_fail=True,
+                 timeout=None, long_retries=False,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
+        """ Creates and sends a request to the given server
+        Args:
+            destination (str): The remote server to send the HTTP request to.
+            method (str): HTTP method
+            path (str): The HTTP path
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): Back off if we get a 404
 
-        url_bytes = self._create_url(
-            destination, path_bytes, param_bytes, query_bytes
-        )
+        Returns:
+            Deferred: resolves with the http response object on success.
 
-        txn_id = "%s-O-%s" % (method, self._next_id)
-        self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+            Fails with ``HTTPRequestException``: if we get an HTTP response
+                code >= 300.
 
-        outbound_logger.info(
-            "{%s} [%s] Sending request: %s %s",
-            txn_id, destination, method, url_bytes
-        )
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+                to retry this server.
 
-        # XXX: Would be much nicer to retry only at the transaction-layer
-        # (once we have reliable transactions in place)
-        if long_retries:
-            retries_left = MAX_LONG_RETRIES
-        else:
-            retries_left = MAX_SHORT_RETRIES
+            Fails with ``FederationDeniedError`` if this destination
+                is not on our federation whitelist
 
-        http_url_bytes = urlparse.urlunparse(
-            ("", "", path_bytes, param_bytes, query_bytes, "")
+            (May also fail with plenty of other Exceptions for things like DNS
+                failures, connection failures, SSL failures.)
+        """
+        if (
+            self.hs.config.federation_domain_whitelist and
+            destination not in self.hs.config.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(destination)
+
+        limiter = yield synapse.util.retryutils.get_retry_limiter(
+            destination,
+            self.clock,
+            self._store,
+            backoff_on_404=backoff_on_404,
+            ignore_backoff=ignore_backoff,
         )
 
-        log_result = None
-        try:
-            while True:
-                producer = None
-                if body_callback:
-                    producer = body_callback(method, http_url_bytes, headers_dict)
-
-                try:
-                    def send_request():
-                        request_deferred = preserve_context_over_fn(
-                            self.agent.request,
+        destination = destination.encode("ascii")
+        path_bytes = path.encode("ascii")
+        with limiter:
+            headers_dict[b"User-Agent"] = [self.version_string]
+            headers_dict[b"Host"] = [destination]
+
+            url_bytes = self._create_url(
+                destination, path_bytes, param_bytes, query_bytes
+            )
+
+            txn_id = "%s-O-%s" % (method, self._next_id)
+            self._next_id = (self._next_id + 1) % (sys.maxint - 1)
+
+            outbound_logger.info(
+                "{%s} [%s] Sending request: %s %s",
+                txn_id, destination, method, url_bytes
+            )
+
+            # XXX: Would be much nicer to retry only at the transaction-layer
+            # (once we have reliable transactions in place)
+            if long_retries:
+                retries_left = MAX_LONG_RETRIES
+            else:
+                retries_left = MAX_SHORT_RETRIES
+
+            http_url_bytes = urlparse.urlunparse(
+                ("", "", path_bytes, param_bytes, query_bytes, "")
+            )
+
+            log_result = None
+            try:
+                while True:
+                    producer = None
+                    if body_callback:
+                        producer = body_callback(method, http_url_bytes, headers_dict)
+
+                    try:
+                        request_deferred = self.agent.request(
                             method,
                             url_bytes,
                             Headers(headers_dict),
                             producer
                         )
-
-                        return self.clock.time_bound_deferred(
+                        add_timeout_to_deferred(
+                            request_deferred,
+                            timeout / 1000. if timeout else 60,
+                            cancelled_to_request_timed_out_error,
+                        )
+                        response = yield make_deferred_yieldable(
                             request_deferred,
-                            time_out=timeout / 1000. if timeout else 60,
                         )
 
-                    response = yield preserve_context_over_fn(send_request)
+                        log_result = "%d %s" % (response.code, response.phrase,)
+                        break
+                    except Exception as e:
+                        if not retry_on_dns_fail and isinstance(e, DNSLookupError):
+                            logger.warn(
+                                "DNS Lookup failed to %s with %s",
+                                destination,
+                                e
+                            )
+                            log_result = "DNS Lookup failed to %s with %s" % (
+                                destination, e
+                            )
+                            raise
 
-                    log_result = "%d %s" % (response.code, response.phrase,)
-                    break
-                except Exception as e:
-                    if not retry_on_dns_fail and isinstance(e, DNSLookupError):
                         logger.warn(
-                            "DNS Lookup failed to %s with %s",
+                            "{%s} Sending request failed to %s: %s %s: %s",
+                            txn_id,
                             destination,
-                            e
-                        )
-                        log_result = "DNS Lookup failed to %s with %s" % (
-                            destination, e
+                            method,
+                            url_bytes,
+                            _flatten_response_never_received(e),
                         )
-                        raise
-
-                    logger.warn(
-                        "{%s} Sending request failed to %s: %s %s: %s - %s",
-                        txn_id,
-                        destination,
-                        method,
-                        url_bytes,
-                        type(e).__name__,
-                        _flatten_response_never_received(e),
-                    )
-
-                    log_result = "%s - %s" % (
-                        type(e).__name__, _flatten_response_never_received(e),
-                    )
-
-                    if retries_left and not timeout:
-                        if long_retries:
-                            delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
-                            delay = min(delay, 60)
-                            delay *= random.uniform(0.8, 1.4)
+
+                        log_result = _flatten_response_never_received(e)
+
+                        if retries_left and not timeout:
+                            if long_retries:
+                                delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
+                                delay = min(delay, 60)
+                                delay *= random.uniform(0.8, 1.4)
+                            else:
+                                delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
+                                delay = min(delay, 2)
+                                delay *= random.uniform(0.8, 1.4)
+
+                            yield sleep(delay)
+                            retries_left -= 1
                         else:
-                            delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
-                            delay = min(delay, 2)
-                            delay *= random.uniform(0.8, 1.4)
-
-                        yield sleep(delay)
-                        retries_left -= 1
-                    else:
-                        raise
-        finally:
-            outbound_logger.info(
-                "{%s} [%s] Result: %s",
-                txn_id,
-                destination,
-                log_result,
-            )
+                            raise
+            finally:
+                outbound_logger.info(
+                    "{%s} [%s] Result: %s",
+                    txn_id,
+                    destination,
+                    log_result,
+                )
 
-        if 200 <= response.code < 300:
-            pass
-        else:
-            # :'(
-            # Update transactions table?
-            body = yield preserve_context_over_fn(readBody, response)
-            raise HttpResponseException(
-                response.code, response.phrase, body
-            )
+            if 200 <= response.code < 300:
+                pass
+            else:
+                # :'(
+                # Update transactions table?
+                with logcontext.PreserveLoggingContext():
+                    body = yield readBody(response)
+                raise HttpResponseException(
+                    response.code, response.phrase, body
+                )
 
-        defer.returnValue(response)
+            defer.returnValue(response)
 
     def sign_request(self, destination, method, url_bytes, headers_dict,
                      content=None):
@@ -247,14 +287,18 @@ class MatrixFederationHttpClient(object):
         headers_dict[b"Authorization"] = auth_headers
 
     @defer.inlineCallbacks
-    def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False, timeout=None):
+    def put_json(self, destination, path, args={}, data={},
+                 json_data_callback=None,
+                 long_retries=False, timeout=None,
+                 ignore_backoff=False,
+                 backoff_on_404=False):
         """ Sends the specifed json data using PUT
 
         Args:
             destination (str): The remote server to send the HTTP request
                 to.
             path (str): The HTTP path.
+            args (dict): query params
             data (dict): A dict containing the data that will be used as
                 the request body. This will be encoded as JSON.
             json_data_callback (callable): A callable returning the dict to
@@ -263,11 +307,24 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
+            backoff_on_404 (bool): True if we should count a 404 response as
+                a failure of the server (and should therefore back off future
+                requests)
 
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
-            will be the decoded JSON body. On a 4xx or 5xx error response a
-            CodeMessageException is raised.
+            will be the decoded JSON body.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         if not json_data_callback:
@@ -282,26 +339,30 @@ class MatrixFederationHttpClient(object):
             producer = _JsonProducer(json_data)
             return producer
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "PUT",
-            path.encode("ascii"),
+            path,
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
+            query_bytes=encode_query_args(args),
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
+            backoff_on_404=backoff_on_404,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def post_json(self, destination, path, data={}, long_retries=False,
-                  timeout=None):
+                  timeout=None, ignore_backoff=False, args={}):
         """ Sends the specifed json data using POST
 
         Args:
@@ -314,11 +375,21 @@ class MatrixFederationHttpClient(object):
                 retry for a short or long time.
             timeout(int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout.
-
+            ignore_backoff (bool): true to ignore the historical backoff data and
+                try the request anyway.
+            args (dict): query params
         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result
-            will be the decoded JSON body. On a 4xx or 5xx error response a
-            CodeMessageException is raised.
+            will be the decoded JSON body.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         def body_callback(method, url_bytes, headers_dict):
@@ -327,27 +398,30 @@ class MatrixFederationHttpClient(object):
             )
             return _JsonProducer(data)
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "POST",
-            path.encode("ascii"),
+            path,
+            query_bytes=encode_query_args(args),
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
-                 timeout=None):
+                 timeout=None, ignore_backoff=False):
         """ GETs some json from the given host homeserver and path
 
         Args:
@@ -359,57 +433,122 @@ class MatrixFederationHttpClient(object):
             timeout (int): How long to try (in ms) the destination for before
                 giving up. None indicates no timeout and that the request will
                 be retried.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
-            Deferred: Succeeds when we get *any* HTTP response.
+            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            will be the decoded JSON body.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
 
-            The result of the deferred is a tuple of `(code, response)`,
-            where `response` is a dict representing the decoded JSON body.
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
         logger.debug("get_json args: %s", args)
 
-        encoded_args = {}
-        for k, vs in args.items():
-            if isinstance(vs, basestring):
-                vs = [vs]
-            encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
-        query_bytes = urllib.urlencode(encoded_args, True)
         logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
 
         def body_callback(method, url_bytes, headers_dict):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
-            query_bytes=query_bytes,
+            path,
+            query_bytes=encode_query_args(args),
             body_callback=body_callback,
             retry_on_dns_fail=retry_on_dns_fail,
             timeout=timeout,
+            ignore_backoff=ignore_backoff,
         )
 
         if 200 <= response.code < 300:
             # We need to update the transactions table to say it was sent?
             check_content_type_is_json(response.headers)
 
-        body = yield preserve_context_over_fn(readBody, response)
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
+
+        defer.returnValue(json.loads(body))
+
+    @defer.inlineCallbacks
+    def delete_json(self, destination, path, long_retries=False,
+                    timeout=None, ignore_backoff=False, args={}):
+        """Send a DELETE request to the remote expecting some json response
+
+        Args:
+            destination (str): The remote server to send the HTTP request
+                to.
+            path (str): The HTTP path.
+            long_retries (bool): A boolean that indicates whether we should
+                retry for a short or long time.
+            timeout(int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
+            ignore_backoff (bool): true to ignore the historical backoff data and
+                try the request anyway.
+        Returns:
+            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            will be the decoded JSON body.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response
+            code >= 300.
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
+        """
+
+        response = yield self._request(
+            destination,
+            "DELETE",
+            path,
+            query_bytes=encode_query_args(args),
+            headers_dict={"Content-Type": ["application/json"]},
+            long_retries=long_retries,
+            timeout=timeout,
+            ignore_backoff=ignore_backoff,
+        )
+
+        if 200 <= response.code < 300:
+            # We need to update the transactions table to say it was sent?
+            check_content_type_is_json(response.headers)
+
+        with logcontext.PreserveLoggingContext():
+            body = yield readBody(response)
 
         defer.returnValue(json.loads(body))
 
     @defer.inlineCallbacks
     def get_file(self, destination, path, output_stream, args={},
-                 retry_on_dns_fail=True, max_size=None):
+                 retry_on_dns_fail=True, max_size=None,
+                 ignore_backoff=False):
         """GETs a file from a given homeserver
         Args:
             destination (str): The remote server to send the HTTP request to.
             path (str): The HTTP path to GET.
             output_stream (file): File to write the response body to.
             args (dict): Optional dictionary used to create the query string.
+            ignore_backoff (bool): true to ignore the historical backoff data
+                and try the request anyway.
         Returns:
-            A (int,dict) tuple of the file length and a dict of the response
-            headers.
+            Deferred: resolves with an (int,dict) tuple of the file length and
+            a dict of the response headers.
+
+            Fails with ``HTTPRequestException`` if we get an HTTP response code
+            >= 300
+
+            Fails with ``NotRetryingDestination`` if we are not yet ready
+            to retry this server.
+
+            Fails with ``FederationDeniedError`` if this destination
+            is not on our federation whitelist
         """
 
         encoded_args = {}
@@ -419,29 +558,30 @@ class MatrixFederationHttpClient(object):
             encoded_args[k] = [v.encode("UTF-8") for v in vs]
 
         query_bytes = urllib.urlencode(encoded_args, True)
-        logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
+        logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
 
         def body_callback(method, url_bytes, headers_dict):
             self.sign_request(destination, method, url_bytes, headers_dict)
             return None
 
-        response = yield self._create_request(
-            destination.encode("ascii"),
+        response = yield self._request(
+            destination,
             "GET",
-            path.encode("ascii"),
+            path,
             query_bytes=query_bytes,
             body_callback=body_callback,
-            retry_on_dns_fail=retry_on_dns_fail
+            retry_on_dns_fail=retry_on_dns_fail,
+            ignore_backoff=ignore_backoff,
         )
 
         headers = dict(response.headers.getAllRawHeaders())
 
         try:
-            length = yield preserve_context_over_fn(
-                _readBodyToFile,
-                response, output_stream, max_size
-            )
-        except:
+            with logcontext.PreserveLoggingContext():
+                length = yield _readBodyToFile(
+                    response, output_stream, max_size
+                )
+        except Exception:
             logger.exception("Failed to download body")
             raise
 
@@ -506,12 +646,14 @@ class _JsonProducer(object):
 
 def _flatten_response_never_received(e):
     if hasattr(e, "reasons"):
-        return ", ".join(
+        reasons = ", ".join(
             _flatten_response_never_received(f.value)
             for f in e.reasons
         )
+
+        return "%s:[%s]" % (type(e).__name__, reasons)
     else:
-        return "%s: %s" % (type(e).__name__, e.message,)
+        return repr(e)
 
 
 def check_content_type_is_json(headers):
@@ -538,3 +680,15 @@ def check_content_type_is_json(headers):
         raise RuntimeError(
             "Content-Type not application/json: was '%s'" % c_type
         )
+
+
+def encode_query_args(args):
+    encoded_args = {}
+    for k, vs in args.items():
+        if isinstance(vs, basestring):
+            vs = [vs]
+        encoded_args[k] = [v.encode("UTF-8") for v in vs]
+
+    query_bytes = urllib.urlencode(encoded_args, True)
+
+    return query_bytes
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 14715878c5..55b9ad5251 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,6 +29,7 @@ from canonicaljson import (
 )
 
 from twisted.internet import defer
+from twisted.python import failure
 from twisted.web import server, resource
 from twisted.web.server import NOT_DONE_YET
 from twisted.web.util import redirectTo
@@ -35,42 +37,86 @@ from twisted.web.util import redirectTo
 import collections
 import logging
 import urllib
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-incoming_requests_counter = metrics.register_counter(
-    "requests",
+# total number of responses served, split by method/servlet/tag
+response_count = metrics.register_counter(
+    "response_count",
     labels=["method", "servlet", "tag"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_requests",
+            "_response_time:count",
+            "_response_ru_utime:count",
+            "_response_ru_stime:count",
+            "_response_db_txn_count:count",
+            "_response_db_txn_duration:count",
+        )
+    )
+)
+
+requests_counter = metrics.register_counter(
+    "requests_received",
+    labels=["method", "servlet", ],
 )
+
 outgoing_responses_counter = metrics.register_counter(
     "responses",
     labels=["method", "code"],
 )
 
-response_timer = metrics.register_distribution(
-    "response_time",
-    labels=["method", "servlet", "tag"]
+response_timer = metrics.register_counter(
+    "response_time_seconds",
+    labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_time:total",
+    ),
 )
 
-response_ru_utime = metrics.register_distribution(
-    "response_ru_utime", labels=["method", "servlet", "tag"]
+response_ru_utime = metrics.register_counter(
+    "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_utime:total",
+    ),
 )
 
-response_ru_stime = metrics.register_distribution(
-    "response_ru_stime", labels=["method", "servlet", "tag"]
+response_ru_stime = metrics.register_counter(
+    "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_ru_stime:total",
+    ),
 )
 
-response_db_txn_count = metrics.register_distribution(
-    "response_db_txn_count", labels=["method", "servlet", "tag"]
+response_db_txn_count = metrics.register_counter(
+    "response_db_txn_count", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_count:total",
+    ),
 )
 
-response_db_txn_duration = metrics.register_distribution(
-    "response_db_txn_duration", labels=["method", "servlet", "tag"]
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = metrics.register_counter(
+    "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
+    alternative_names=(
+        metrics.name_prefix + "_response_db_txn_duration:total",
+    ),
 )
 
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = metrics.register_counter(
+    "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
+)
+
+# size in bytes of the response written
+response_size = metrics.register_counter(
+    "response_size", labels=["method", "servlet", "tag"]
+)
 
 _next_request_id = 0
 
@@ -106,7 +152,12 @@ def wrap_request_handler(request_handler, include_metrics=False):
         with LoggingContext(request_id) as request_context:
             with Measure(self.clock, "wrapped_request_handler"):
                 request_metrics = RequestMetrics()
-                request_metrics.start(self.clock, name=self.__class__.__name__)
+                # we start the request metrics timer here with an initial stab
+                # at the servlet name. For most requests that name will be
+                # JsonResource (or a subclass), and JsonResource._async_render
+                # will update it once it picks a servlet.
+                servlet_name = self.__class__.__name__
+                request_metrics.start(self.clock, name=servlet_name)
 
                 request_context.request = request_id
                 with request.processing():
@@ -115,6 +166,7 @@ def wrap_request_handler(request_handler, include_metrics=False):
                             if include_metrics:
                                 yield request_handler(self, request, request_metrics)
                             else:
+                                requests_counter.inc(request.method, servlet_name)
                                 yield request_handler(self, request)
                     except CodeMessageException as e:
                         code = e.code
@@ -130,13 +182,18 @@ def wrap_request_handler(request_handler, include_metrics=False):
                             pretty_print=_request_user_agent_is_curl(request),
                             version_string=self.version_string,
                         )
-                    except:
-                        logger.exception(
-                            "Failed handle request %s.%s on %r: %r",
+                    except Exception:
+                        # failure.Failure() fishes the original Failure out
+                        # of our stack, and thus gives us a sensible stack
+                        # trace.
+                        f = failure.Failure()
+                        logger.error(
+                            "Failed handle request %s.%s on %r: %r: %s",
                             request_handler.__module__,
                             request_handler.__name__,
                             self,
-                            request
+                            request,
+                            f.getTraceback().rstrip(),
                         )
                         respond_with_json(
                             request,
@@ -145,7 +202,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
                                 "error": "Internal server error",
                                 "errcode": Codes.UNKNOWN,
                             },
-                            send_cors=True
+                            send_cors=True,
+                            pretty_print=_request_user_agent_is_curl(request),
+                            version_string=self.version_string,
                         )
                     finally:
                         try:
@@ -183,7 +242,7 @@ class JsonResource(HttpServer, resource.Resource):
     """ This implements the HttpServer interface and provides JSON support for
     Resources.
 
-    Register callbacks via register_path()
+    Register callbacks via register_paths()
 
     Callbacks can return a tuple of status code and a dict in which case the
     the dict will automatically be sent to the client as a JSON object.
@@ -230,57 +289,62 @@ class JsonResource(HttpServer, resource.Resource):
             This checks if anyone has registered a callback for that method and
             path.
         """
-        if request.method == "OPTIONS":
-            self._send_response(request, 200, {})
-            return
+        callback, group_dict = self._get_handler_for_request(request)
 
-        # Loop through all the registered callbacks to check if the method
-        # and path regex match
-        for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request.path)
-            if not m:
-                continue
+        servlet_instance = getattr(callback, "__self__", None)
+        if servlet_instance is not None:
+            servlet_classname = servlet_instance.__class__.__name__
+        else:
+            servlet_classname = "%r" % callback
 
-            # We found a match! Trigger callback and then return the
-            # returned response. We pass both the request and any
-            # matched groups from the regex to the callback.
+        request_metrics.name = servlet_classname
+        requests_counter.inc(request.method, servlet_classname)
 
-            callback = path_entry.callback
+        # Now trigger the callback. If it returns a response, we send it
+        # here. If it throws an exception, that is handled by the wrapper
+        # installed by @request_handler.
 
-            kwargs = intern_dict({
-                name: urllib.unquote(value).decode("UTF-8") if value else value
-                for name, value in m.groupdict().items()
-            })
+        kwargs = intern_dict({
+            name: urllib.unquote(value).decode("UTF-8") if value else value
+            for name, value in group_dict.items()
+        })
 
-            callback_return = yield callback(request, **kwargs)
-            if callback_return is not None:
-                code, response = callback_return
-                self._send_response(request, code, response)
+        callback_return = yield callback(request, **kwargs)
+        if callback_return is not None:
+            code, response = callback_return
+            self._send_response(request, code, response)
 
-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
+    def _get_handler_for_request(self, request):
+        """Finds a callback method to handle the given request
 
-            request_metrics.name = servlet_classname
+        Args:
+            request (twisted.web.http.Request):
+
+        Returns:
+            Tuple[Callable, dict[str, str]]: callback method, and the dict
+                mapping keys to path components as specified in the handler's
+                path match regexp.
 
-            return
+                The callback will normally be a method registered via
+                register_paths, so will return (possibly via Deferred) either
+                None, or a tuple of (http code, response body).
+        """
+        if request.method == b"OPTIONS":
+            return _options_handler, {}
+
+        # Loop through all the registered callbacks to check if the method
+        # and path regex match
+        for path_entry in self.path_regexs.get(request.method, []):
+            m = path_entry.pattern.match(request.path)
+            if m:
+                # We found a match!
+                return path_entry.callback, m.groupdict()
 
         # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        raise UnrecognizedRequestError()
+        return _unrecognised_request_handler, {}
 
     def _send_response(self, request, code, response_json_object,
                        response_code_message=None):
-        # could alternatively use request.notifyFinish() and flip a flag when
-        # the Deferred fires, but since the flag is RIGHT THERE it seems like
-        # a waste.
-        if request._disconnected:
-            logger.warn(
-                "Not sending response to request %s, already disconnected.",
-                request)
-            return
-
         outgoing_responses_counter.inc(request.method, str(code))
 
         # TODO: Only enable CORS for the requests that need it.
@@ -294,6 +358,34 @@ class JsonResource(HttpServer, resource.Resource):
         )
 
 
+def _options_handler(request):
+    """Request handler for OPTIONS requests
+
+    This is a request handler suitable for return from
+    _get_handler_for_request. It returns a 200 and an empty body.
+
+    Args:
+        request (twisted.web.http.Request):
+
+    Returns:
+        Tuple[int, dict]: http code, response body.
+    """
+    return 200, {}
+
+
+def _unrecognised_request_handler(request):
+    """Request handler for unrecognised requests
+
+    This is a request handler suitable for return from
+    _get_handler_for_request. It actually just raises an
+    UnrecognizedRequestError.
+
+    Args:
+        request (twisted.web.http.Request):
+    """
+    raise UnrecognizedRequestError()
+
+
 class RequestMetrics(object):
     def start(self, clock, name):
         self.start = clock.time_msec()
@@ -314,7 +406,7 @@ class RequestMetrics(object):
                 )
                 return
 
-        incoming_requests_counter.inc(request.method, self.name, tag)
+        response_count.inc(request.method, self.name, tag)
 
         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
@@ -333,9 +425,14 @@ class RequestMetrics(object):
             context.db_txn_count, request.method, self.name, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration_ms / 1000., request.method, self.name, tag
+        )
+        response_db_sched_duration.inc_by(
+            context.db_sched_duration_ms / 1000., request.method, self.name, tag
         )
 
+        response_size.inc_by(request.sentLength, request.method, self.name, tag)
+
 
 class RootRedirect(resource.Resource):
     """Redirects the root '/' path to another path."""
@@ -356,14 +453,22 @@ class RootRedirect(resource.Resource):
 def respond_with_json(request, code, json_object, send_cors=False,
                       response_code_message=None, pretty_print=False,
                       version_string="", canonical_json=True):
+    # could alternatively use request.notifyFinish() and flip a flag when
+    # the Deferred fires, but since the flag is RIGHT THERE it seems like
+    # a waste.
+    if request._disconnected:
+        logger.warn(
+            "Not sending response to request %s, already disconnected.",
+            request)
+        return
+
     if pretty_print:
         json_bytes = encode_pretty_printed_json(json_object) + "\n"
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
             json_bytes = encode_canonical_json(json_object)
         else:
-            # ujson doesn't like frozen_dicts.
-            json_bytes = ujson.dumps(json_object, ensure_ascii=False)
+            json_bytes = simplejson.dumps(json_object)
 
     return respond_with_json_bytes(
         request, code, json_bytes,
@@ -390,6 +495,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
     request.setHeader(b"Content-Type", b"application/json")
     request.setHeader(b"Server", version_string)
     request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
+    request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
 
     if send_cors:
         set_cors_headers(request)
@@ -412,7 +518,7 @@ def set_cors_headers(request):
     )
     request.setHeader(
         "Access-Control-Allow-Headers",
-        "Origin, X-Requested-With, Content-Type, Accept"
+        "Origin, X-Requested-With, Content-Type, Accept, Authorization"
     )
 
 
@@ -437,9 +543,9 @@ def finish_request(request):
 
 def _request_user_agent_is_curl(request):
     user_agents = request.requestHeaders.getRawHeaders(
-        "User-Agent", default=[]
+        b"User-Agent", default=[]
     )
     for user_agent in user_agents:
-        if "curl" in user_agent:
+        if b"curl" in user_agent:
             return True
     return False
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 8c22d6f00f..ef8e62901b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -48,7 +48,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
     if name in args:
         try:
             return int(args[name][0])
-        except:
+        except Exception:
             message = "Query parameter %r must be an integer" % (name,)
             raise SynapseError(400, message)
     else:
@@ -88,7 +88,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
                 "true": True,
                 "false": False,
             }[args[name][0]]
-        except:
+        except Exception:
             message = (
                 "Boolean query parameter %r must be one of"
                 " ['true', 'false']"
@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
             return default
 
 
-def parse_json_value_from_request(request):
+def parse_json_value_from_request(request, allow_empty_body=False):
     """Parse a JSON value from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into None
 
     Returns:
         The JSON value.
@@ -162,28 +164,39 @@ def parse_json_value_from_request(request):
     """
     try:
         content_bytes = request.content.read()
-    except:
+    except Exception:
         raise SynapseError(400, "Error reading JSON content.")
 
+    if not content_bytes and allow_empty_body:
+        return None
+
     try:
         content = simplejson.loads(content_bytes)
-    except simplejson.JSONDecodeError:
+    except Exception as e:
+        logger.warn("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
 
     return content
 
 
-def parse_json_object_from_request(request):
+def parse_json_object_from_request(request, allow_empty_body=False):
     """Parse a JSON object from the body of a twisted HTTP request.
 
     Args:
         request: the twisted HTTP request.
+        allow_empty_body (bool): if True, an empty body will be accepted and
+            turned into an empty dict.
 
     Raises:
         SynapseError if the request body couldn't be decoded as JSON or
             if it wasn't a JSON object.
     """
-    content = parse_json_value_from_request(request)
+    content = parse_json_value_from_request(
+        request, allow_empty_body=allow_empty_body,
+    )
+
+    if allow_empty_body and content is None:
+        return {}
 
     if type(content) != dict:
         message = "Content must be a JSON object."
@@ -192,6 +205,16 @@ def parse_json_object_from_request(request):
     return content
 
 
+def assert_params_in_request(body, required):
+    absent = []
+    for k in required:
+        if k not in body:
+            absent.append(k)
+
+    if len(absent) > 0:
+        raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+
 class RestServlet(object):
 
     """ A Synapse REST Servlet.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4b09d7ee66..c8b46e1af2 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ import logging
 import re
 import time
 
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
 
 
 class SynapseRequest(Request):
@@ -43,12 +43,12 @@ class SynapseRequest(Request):
 
     def get_redacted_uri(self):
         return ACCESS_TOKEN_RE.sub(
-            r'\1<redacted>\3',
+            br'\1<redacted>\3',
             self.uri
         )
 
     def get_user_agent(self):
-        return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+        return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
 
     def started_processing(self):
         self.site.access_logger.info(
@@ -66,14 +66,15 @@ class SynapseRequest(Request):
             context = LoggingContext.current_context()
             ru_utime, ru_stime = context.get_resource_usage()
             db_txn_count = context.db_txn_count
-            db_txn_duration = context.db_txn_duration
-        except:
+            db_txn_duration_ms = context.db_txn_duration_ms
+            db_sched_duration_ms = context.db_sched_duration_ms
+        except Exception:
             ru_utime, ru_stime = (0, 0)
-            db_txn_count, db_txn_duration = (0, 0)
+            db_txn_count, db_txn_duration_ms = (0, 0)
 
         self.site.access_logger.info(
             "%s - %s - {%s}"
-            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
             " %sB %s \"%s %s %s\" \"%s\"",
             self.getClientIP(),
             self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
             int(time.time() * 1000) - self.start_time,
             int(ru_utime * 1000),
             int(ru_stime * 1000),
-            int(db_txn_duration * 1000),
+            db_sched_duration_ms,
+            db_txn_duration_ms,
             int(db_txn_count),
             self.sentLength,
             self.code,
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..e3b831db67 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -17,12 +17,13 @@ import logging
 import functools
 import time
 import gc
+import platform
 
 from twisted.internet import reactor
 
 from .metric import (
     CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
-    MemoryUsageMetric,
+    MemoryUsageMetric, GaugeMetric,
 )
 from .process_collector import register_process_collector
 
@@ -30,6 +31,7 @@ from .process_collector import register_process_collector
 logger = logging.getLogger(__name__)
 
 
+running_on_pypy = platform.python_implementation() == 'PyPy'
 all_metrics = []
 all_collectors = []
 
@@ -57,15 +59,38 @@ class Metrics(object):
         return metric
 
     def register_counter(self, *args, **kwargs):
+        """
+        Returns:
+            CounterMetric
+        """
         return self._register(CounterMetric, *args, **kwargs)
 
+    def register_gauge(self, *args, **kwargs):
+        """
+        Returns:
+            GaugeMetric
+        """
+        return self._register(GaugeMetric, *args, **kwargs)
+
     def register_callback(self, *args, **kwargs):
+        """
+        Returns:
+            CallbackMetric
+        """
         return self._register(CallbackMetric, *args, **kwargs)
 
     def register_distribution(self, *args, **kwargs):
+        """
+        Returns:
+            DistributionMetric
+        """
         return self._register(DistributionMetric, *args, **kwargs)
 
     def register_cache(self, *args, **kwargs):
+        """
+        Returns:
+            CacheMetric
+        """
         return self._register(CacheMetric, *args, **kwargs)
 
 
@@ -126,6 +151,32 @@ reactor_metrics = get_metrics_for("python.twisted.reactor")
 tick_time = reactor_metrics.register_distribution("tick_time")
 pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
 
+synapse_metrics = get_metrics_for("synapse")
+
+# Used to track where various components have processed in the event stream,
+# e.g. federation sending, appservice sending, etc.
+event_processing_positions = synapse_metrics.register_gauge(
+    "event_processing_positions", labels=["name"],
+)
+
+# Used to track the current max events stream position
+event_persisted_position = synapse_metrics.register_gauge(
+    "event_persisted_position",
+)
+
+# Used to track the received_ts of the last event processed by various
+# components
+event_processing_last_ts = synapse_metrics.register_gauge(
+    "event_processing_last_ts", labels=["name"],
+)
+
+# Used to track the lag processing events. This is the time difference
+# between the last processed event's received_ts and the time it was
+# finished being processed.
+event_processing_lag = synapse_metrics.register_gauge(
+    "event_processing_lag", labels=["name"],
+)
+
 
 def runUntilCurrentTimer(func):
 
@@ -146,13 +197,21 @@ def runUntilCurrentTimer(func):
             num_pending += 1
 
         num_pending += len(reactor.threadCallQueue)
-
         start = time.time() * 1000
         ret = func(*args, **kwargs)
         end = time.time() * 1000
+
+        # record the amount of wallclock time spent running pending calls.
+        # This is a proxy for the actual amount of time between reactor polls,
+        # since about 25% of time is actually spent running things triggered by
+        # I/O events, but that is harder to capture without rewriting half the
+        # reactor.
         tick_time.inc_by(end - start)
         pending_calls_metric.inc_by(num_pending)
 
+        if running_on_pypy:
+            return ret
+
         # Check if we need to do a manual GC (since its been disabled), and do
         # one if necessary.
         threshold = gc.get_threshold()
@@ -185,6 +244,7 @@ try:
 
     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
-    gc.disable()
+    if not running_on_pypy:
+        gc.disable()
 except AttributeError:
     pass
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..fbba94e633 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -15,18 +15,39 @@
 
 
 from itertools import chain
+import logging
+import re
 
+logger = logging.getLogger(__name__)
 
-# TODO(paul): I can't believe Python doesn't have one of these
-def map_concat(func, items):
-    # flatten a list-of-lists
-    return list(chain.from_iterable(map(func, items)))
+
+def flatten(items):
+    """Flatten a list of lists
+
+    Args:
+        items: iterable[iterable[X]]
+
+    Returns:
+        list[X]: flattened list
+    """
+    return list(chain.from_iterable(items))
 
 
 class BaseMetric(object):
+    """Base class for metrics which report a single value per label set
+    """
 
-    def __init__(self, name, labels=[]):
-        self.name = name
+    def __init__(self, name, labels=[], alternative_names=[]):
+        """
+        Args:
+            name (str): principal name for this metric
+            labels (list(str)): names of the labels which will be reported
+                for this metric
+            alternative_names (iterable(str)): list of alternative names for
+                 this metric. This can be useful to provide a migration path
+                when renaming metrics.
+        """
+        self._names = [name] + list(alternative_names)
         self.labels = labels  # OK not to clone as we never write it
 
     def dimension(self):
@@ -36,8 +57,7 @@ class BaseMetric(object):
         return not len(self.labels)
 
     def _render_labelvalue(self, value):
-        # TODO: some kind of value escape
-        return '"%s"' % (value)
+        return '"%s"' % (_escape_label_value(value),)
 
     def _render_key(self, values):
         if self.is_scalar():
@@ -47,19 +67,60 @@ class BaseMetric(object):
                       for k, v in zip(self.labels, values)])
         )
 
+    def _render_for_labels(self, label_values, value):
+        """Render this metric for a single set of labels
+
+        Args:
+            label_values (list[str]): values for each of the labels
+            value: value of the metric at with these labels
+
+        Returns:
+            iterable[str]: rendered metric
+        """
+        rendered_labels = self._render_key(label_values)
+        return (
+            "%s%s %.12g" % (name, rendered_labels, value)
+            for name in self._names
+        )
+
+    def render(self):
+        """Render this metric
+
+        Each metric is rendered as:
+
+            name{label1="val1",label2="val2"} value
+
+        https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
+
+        Returns:
+            iterable[str]: rendered metrics
+        """
+        raise NotImplementedError()
+
 
 class CounterMetric(BaseMetric):
     """The simplest kind of metric; one that stores a monotonically-increasing
-    integer that counts events."""
+    value that counts events or running totals.
+
+    Example use cases for Counters:
+    - Number of requests processed
+    - Number of items that were inserted into a queue
+    - Total amount of data that a system has processed
+    Counters can only go up (and be reset when the process restarts).
+    """
 
     def __init__(self, *args, **kwargs):
         super(CounterMetric, self).__init__(*args, **kwargs)
 
+        # dict[list[str]]: value for each set of label values. the keys are the
+        # label values, in the same order as the labels in self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty tuple).
         self.counts = {}
 
         # Scalar metrics are never empty
         if self.is_scalar():
-            self.counts[()] = 0
+            self.counts[()] = 0.
 
     def inc_by(self, incr, *values):
         if len(values) != self.dimension():
@@ -77,11 +138,41 @@ class CounterMetric(BaseMetric):
     def inc(self, *values):
         self.inc_by(1, *values)
 
-    def render_item(self, k):
-        return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
+    def render(self):
+        return flatten(
+            self._render_for_labels(k, self.counts[k])
+            for k in sorted(self.counts.keys())
+        )
+
+
+class GaugeMetric(BaseMetric):
+    """A metric that can go up or down
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(GaugeMetric, self).__init__(*args, **kwargs)
+
+        # dict[list[str]]: value for each set of label values. the keys are the
+        # label values, in the same order as the labels in self.labels.
+        #
+        # (if the metric is a scalar, the (single) key is the empty tuple).
+        self.guages = {}
+
+    def set(self, v, *values):
+        if len(values) != self.dimension():
+            raise ValueError(
+                "Expected as many values to inc() as labels (%d)" % (self.dimension())
+            )
+
+        # TODO: should assert that the tag values are all strings
+
+        self.guages[values] = v
 
     def render(self):
-        return map_concat(self.render_item, sorted(self.counts.keys()))
+        return flatten(
+            self._render_for_labels(k, self.guages[k])
+            for k in sorted(self.guages.keys())
+        )
 
 
 class CallbackMetric(BaseMetric):
@@ -95,13 +186,19 @@ class CallbackMetric(BaseMetric):
         self.callback = callback
 
     def render(self):
-        value = self.callback()
+        try:
+            value = self.callback()
+        except Exception:
+            logger.exception("Failed to render %s", self.name)
+            return ["# FAILED to render " + self.name]
 
         if self.is_scalar():
-            return ["%s %.12g" % (self.name, value)]
+            return list(self._render_for_labels([], value))
 
-        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
-                for k in sorted(value.keys())]
+        return flatten(
+            self._render_for_labels(k, value[k])
+            for k in sorted(value.keys())
+        )
 
 
 class DistributionMetric(object):
@@ -126,7 +223,9 @@ class DistributionMetric(object):
 
 
 class CacheMetric(object):
-    __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+    __slots__ = (
+        "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+    )
 
     def __init__(self, name, size_callback, cache_name):
         self.name = name
@@ -134,6 +233,7 @@ class CacheMetric(object):
 
         self.hits = 0
         self.misses = 0
+        self.evicted_size = 0
 
         self.size_callback = size_callback
 
@@ -143,6 +243,9 @@ class CacheMetric(object):
     def inc_misses(self):
         self.misses += 1
 
+    def inc_evictions(self, size=1):
+        self.evicted_size += size
+
     def render(self):
         size = self.size_callback()
         hits = self.hits
@@ -152,6 +255,9 @@ class CacheMetric(object):
             """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+            """%s:evicted_size{name="%s"} %d""" % (
+                self.name, self.cache_name, self.evicted_size
+            ),
         ]
 
 
@@ -193,3 +299,29 @@ class MemoryUsageMetric(object):
             "process_psutil_rss:total %d" % sum_rss,
             "process_psutil_rss:count %d" % len_rss,
         ]
+
+
+def _escape_character(m):
+    """Replaces a single character with its escape sequence.
+
+    Args:
+        m (re.MatchObject): A match object whose first group is the single
+            character to replace
+
+    Returns:
+        str
+    """
+    c = m.group(1)
+    if c == "\\":
+        return "\\\\"
+    elif c == "\"":
+        return "\\\""
+    elif c == "\n":
+        return "\\n"
+    return c
+
+
+def _escape_label_value(value):
+    """Takes a label value and escapes quotes, newlines and backslashes
+    """
+    return re.sub(r"([\n\"\\])", _escape_character, value)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
new file mode 100644
index 0000000000..097c844d31
--- /dev/null
+++ b/synapse/module_api/__init__.py
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.types import UserID
+
+
+class ModuleApi(object):
+    """A proxy object that gets passed to password auth providers so they
+    can register new users etc if necessary.
+    """
+    def __init__(self, hs, auth_handler):
+        self.hs = hs
+
+        self._store = hs.get_datastore()
+        self._auth = hs.get_auth()
+        self._auth_handler = auth_handler
+
+    def get_user_by_req(self, req, allow_guest=False):
+        """Check the access_token provided for a request
+
+        Args:
+            req (twisted.web.server.Request): Incoming HTTP request
+            allow_guest (bool): True if guest users should be allowed. If this
+                is False, and the access token is for a guest user, an
+                AuthError will be thrown
+        Returns:
+            twisted.internet.defer.Deferred[synapse.types.Requester]:
+                the requester for this request
+        Raises:
+            synapse.api.errors.AuthError: if no user by that token exists,
+                or the token is invalid.
+        """
+        return self._auth.get_user_by_req(req, allow_guest)
+
+    def get_qualified_user_id(self, username):
+        """Qualify a user id, if necessary
+
+        Takes a user id provided by the user and adds the @ and :domain to
+        qualify it, if necessary
+
+        Args:
+            username (str): provided user id
+
+        Returns:
+            str: qualified @user:id
+        """
+        if username.startswith('@'):
+            return username
+        return UserID(username, self.hs.hostname).to_string()
+
+    def check_user_exists(self, user_id):
+        """Check if user exists.
+
+        Args:
+            user_id (str): Complete @user:id
+
+        Returns:
+            Deferred[str|None]: Canonical (case-corrected) user_id, or None
+               if the user is not registered.
+        """
+        return self._auth_handler.check_user_exists(user_id)
+
+    def register(self, localpart):
+        """Registers a new user with given localpart
+
+        Returns:
+            Deferred: a 2-tuple of (user_id, access_token)
+        """
+        reg = self.hs.get_handlers().registration_handler
+        return reg.register(localpart=localpart)
+
+    @defer.inlineCallbacks
+    def invalidate_access_token(self, access_token):
+        """Invalidate an access token for a user
+
+        Args:
+            access_token(str): access token
+
+        Returns:
+            twisted.internet.defer.Deferred - resolves once the access token
+               has been removed.
+
+        Raises:
+            synapse.api.errors.AuthError: the access token is invalid
+        """
+        # see if the access token corresponds to a device
+        user_info = yield self._auth.get_user_by_access_token(access_token)
+        device_id = user_info.get("device_id")
+        user_id = user_info["user"].to_string()
+        if device_id:
+            # delete the device, which will also delete its access tokens
+            yield self.hs.get_device_handler().delete_device(user_id, device_id)
+        else:
+            # no associated device. Just delete the access token.
+            yield self._auth_handler.delete_access_token(access_token)
+
+    def run_db_interaction(self, desc, func, *args, **kwargs):
+        """Run a function with a database connection
+
+        Args:
+            desc (str): description for the transaction, for metrics etc
+            func (func): function to be run. Passed a database cursor object
+                as well as *args and **kwargs
+            *args: positional args to be passed to func
+            **kwargs: named args to be passed to func
+
+        Returns:
+            Deferred[object]: result of func
+        """
+        return self._store.runInteraction(desc, func, *args, **kwargs)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 8051a7a842..8355c7d621 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,13 +14,17 @@
 # limitations under the License.
 
 from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError
+from synapse.handlers.presence import format_user_presence_state
 
-from synapse.util import DeferredTimedOutError
 from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.async import (
+    ObservableDeferred, add_timeout_to_deferred,
+    DeferredTimeoutError,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
 from synapse.util.metrics import Measure
 from synapse.types import StreamToken
 from synapse.visibility import filter_events_for_client
@@ -37,6 +41,10 @@ metrics = synapse.metrics.get_metrics_for(__name__)
 
 notified_events_counter = metrics.register_counter("notified_events")
 
+users_woken_by_stream_counter = metrics.register_counter(
+    "users_woken_by_stream", labels=["stream"]
+)
+
 
 # TODO(paul): Should be shared somewhere
 def count(func, l):
@@ -73,6 +81,13 @@ class _NotifierUserStream(object):
         self.user_id = user_id
         self.rooms = set(rooms)
         self.current_token = current_token
+
+        # The last token for which we should wake up any streams that have a
+        # token that comes before it. This gets updated everytime we get poked.
+        # We start it at the current token since if we get any streams
+        # that have a token from before we have no idea whether they should be
+        # woken up or not, so lets just wake them up.
+        self.last_notified_token = current_token
         self.last_notified_ms = time_now_ms
 
         with PreserveLoggingContext():
@@ -89,9 +104,12 @@ class _NotifierUserStream(object):
         self.current_token = self.current_token.copy_and_advance(
             stream_key, stream_id
         )
+        self.last_notified_token = self.current_token
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred
 
+        users_woken_by_stream_counter.inc(stream_key)
+
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
             noify_deferred.callback(self.current_token)
@@ -113,8 +131,14 @@ class _NotifierUserStream(object):
     def new_listener(self, token):
         """Returns a deferred that is resolved when there is a new token
         greater than the given token.
+
+        Args:
+            token: The token from which we are streaming from, i.e. we shouldn't
+                notify for things that happened before this.
         """
-        if self.current_token.is_after(token):
+        # Immediately wake up stream if something has already since happened
+        # since their last token.
+        if self.last_notified_token.is_after(token):
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
@@ -123,6 +147,7 @@ class _NotifierUserStream(object):
 class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
     def __nonzero__(self):
         return bool(self.events)
+    __bool__ = __nonzero__  # python3
 
 
 class Notifier(object):
@@ -142,6 +167,8 @@ class Notifier(object):
         self.store = hs.get_datastore()
         self.pending_new_room_events = []
 
+        self.replication_callbacks = []
+
         self.clock = hs.get_clock()
         self.appservice_handler = hs.get_application_service_handler()
 
@@ -181,7 +208,12 @@ class Notifier(object):
             lambda: len(self.user_to_user_stream),
         )
 
-    @preserve_fn
+    def add_replication_callback(self, cb):
+        """Add a callback that will be called when some new data is available.
+        Callback is not given any arguments.
+        """
+        self.replication_callbacks.append(cb)
+
     def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                           extra_users=[]):
         """ Used by handlers to inform the notifier something has happened
@@ -195,15 +227,13 @@ class Notifier(object):
         until all previous events have been persisted before notifying
         the client streams.
         """
-        with PreserveLoggingContext():
-            self.pending_new_room_events.append((
-                room_stream_id, event, extra_users
-            ))
-            self._notify_pending_new_room_events(max_room_stream_id)
+        self.pending_new_room_events.append((
+            room_stream_id, event, extra_users
+        ))
+        self._notify_pending_new_room_events(max_room_stream_id)
 
-            self.notify_replication()
+        self.notify_replication()
 
-    @preserve_fn
     def _notify_pending_new_room_events(self, max_room_stream_id):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
@@ -221,11 +251,10 @@ class Notifier(object):
             else:
                 self._on_new_room_event(event, room_stream_id, extra_users)
 
-    @preserve_fn
     def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
-        self.appservice_handler.notify_interested_services(room_stream_id)
+        run_in_background(self._notify_app_services, room_stream_id)
 
         if self.federation_sender:
             self.federation_sender.notify_new_events(room_stream_id)
@@ -239,7 +268,13 @@ class Notifier(object):
             rooms=[event.room_id],
         )
 
-    @preserve_fn
+    @defer.inlineCallbacks
+    def _notify_app_services(self, room_stream_id):
+        try:
+            yield self.appservice_handler.notify_interested_services(room_stream_id)
+        except Exception:
+            logger.exception("Error notifying application services of event")
+
     def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
         """ Used to inform listeners that something has happend event wise.
 
@@ -261,17 +296,15 @@ class Notifier(object):
                 for user_stream in user_streams:
                     try:
                         user_stream.notify(stream_key, new_token, time_now_ms)
-                    except:
+                    except Exception:
                         logger.exception("Failed to notify listener")
 
                 self.notify_replication()
 
-    @preserve_fn
     def on_new_replication_data(self):
         """Used to inform replication listeners that something has happend
         without waking up any of the normal user event streams"""
-        with PreserveLoggingContext():
-            self.notify_replication()
+        self.notify_replication()
 
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -283,8 +316,7 @@ class Notifier(object):
         if user_stream is None:
             current_token = yield self.event_sources.get_current_token()
             if room_ids is None:
-                rooms = yield self.store.get_rooms_for_user(user_id)
-                room_ids = [room.room_id for room in rooms]
+                room_ids = yield self.store.get_rooms_for_user(user_id)
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
@@ -294,40 +326,45 @@ class Notifier(object):
             self._register_with_keys(user_stream)
 
         result = None
+        prev_token = from_token
         if timeout:
             end_time = self.clock.time_msec() + timeout
 
-            prev_token = from_token
             while not result:
                 try:
-                    current_token = user_stream.current_token
-
-                    result = yield callback(prev_token, current_token)
-                    if result:
-                        break
-
                     now = self.clock.time_msec()
                     if end_time <= now:
                         break
 
                     # Now we wait for the _NotifierUserStream to be told there
                     # is a new token.
-                    # We need to supply the token we supplied to callback so
-                    # that we don't miss any current_token updates.
-                    prev_token = current_token
                     listener = user_stream.new_listener(prev_token)
+                    add_timeout_to_deferred(
+                        listener.deferred,
+                        (end_time - now) / 1000.,
+                    )
                     with PreserveLoggingContext():
-                        yield self.clock.time_bound_deferred(
-                            listener.deferred,
-                            time_out=(end_time - now) / 1000.
-                        )
-                except DeferredTimedOutError:
+                        yield listener.deferred
+
+                    current_token = user_stream.current_token
+
+                    result = yield callback(prev_token, current_token)
+                    if result:
+                        break
+
+                    # Update the prev_token to the current_token since nothing
+                    # has happened between the old prev_token and the current_token
+                    prev_token = current_token
+                except DeferredTimeoutError:
                     break
                 except defer.CancelledError:
                     break
-        else:
+
+        if result is None:
+            # This happened if there was no timeout or if the timeout had
+            # already expired.
             current_token = user_stream.current_token
-            result = yield callback(from_token, current_token)
+            result = yield callback(prev_token, current_token)
 
         defer.returnValue(result)
 
@@ -388,6 +425,15 @@ class Notifier(object):
                         new_events,
                         is_peeking=is_peeking,
                     )
+                elif name == "presence":
+                    now = self.clock.time_msec()
+                    new_events[:] = [
+                        {
+                            "type": "m.presence",
+                            "content": format_user_presence_state(event, now),
+                        }
+                        for event in new_events
+                    ]
 
                 events.extend(new_events)
                 end_token = end_token.copy_and_replace(keyname, new_key)
@@ -420,8 +466,7 @@ class Notifier(object):
 
     @defer.inlineCallbacks
     def _get_room_ids(self, user, explicit_room_id):
-        joined_rooms = yield self.store.get_rooms_for_user(user.to_string())
-        joined_room_ids = map(lambda r: r.room_id, joined_rooms)
+        joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
         if explicit_room_id:
             if explicit_room_id in joined_room_ids:
                 defer.returnValue(([explicit_room_id], True))
@@ -478,6 +523,15 @@ class Notifier(object):
             self.replication_deferred = ObservableDeferred(defer.Deferred())
             deferred.callback(None)
 
+            # the callbacks may well outlast the current request, so we run
+            # them in the sentinel logcontext.
+            #
+            # (ideally it would be up to the callbacks to know if they were
+            # starting off background processes and drop the logcontext
+            # accordingly, but that requires more changes)
+            for cb in self.replication_callbacks:
+                cb()
+
     @defer.inlineCallbacks
     def wait_for_replication(self, callback, timeout):
         """Wait for an event to happen.
@@ -506,13 +560,14 @@ class Notifier(object):
             if end_time <= now:
                 break
 
+            add_timeout_to_deferred(
+                listener.deferred.addTimeout,
+                (end_time - now) / 1000.,
+            )
             try:
                 with PreserveLoggingContext():
-                    yield self.clock.time_bound_deferred(
-                        listener.deferred,
-                        time_out=(end_time - now) / 1000.
-                    )
-            except DeferredTimedOutError:
+                    yield listener.deferred
+            except DeferredTimeoutError:
                 break
             except defer.CancelledError:
                 break
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 3f75d3f921..8f619a7a1b 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from .bulk_push_rule_evaluator import evaluator_for_event
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
 
 from synapse.util.metrics import Measure
 
@@ -24,11 +24,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class ActionGenerator:
+class ActionGenerator(object):
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
+        self.bulk_evaluator = BulkPushRuleEvaluator(hs)
         # really we want to get all user ids and all profile tags too,
         # since we want the actions for each profile tag for every user and
         # also actions for a client with no profile tag for each user.
@@ -38,16 +39,7 @@ class ActionGenerator:
 
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context):
-        with Measure(self.clock, "evaluator_for_event"):
-            bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context
-            )
-
         with Measure(self.clock, "action_for_event_by_user"):
-            actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+            yield self.bulk_evaluator.action_for_event_by_user(
                 event, context
             )
-
-        context.push_actions = [
-            (uid, actions) for uid, actions in actions_by_user.items()
-        ]
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 85effdfa46..7a18afe5f9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
             }
         ]
     },
+    {
+        'rule_id': 'global/override/.m.rule.roomnotif',
+        'conditions': [
+            {
+                'kind': 'event_match',
+                'key': 'content.body',
+                'pattern': '@room',
+                '_id': '_roomnotif_content',
+            },
+            {
+                'kind': 'sender_notification_permission',
+                'key': 'room',
+                '_id': '_roomnotif_pl',
+            },
+        ],
+        'actions': [
+            'notify', {
+                'set_tweak': 'highlight',
+                'value': True,
+            }
+        ]
+    }
 ]
 
 
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 78b095c903..7c680659b6 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,88 +20,166 @@ from twisted.internet import defer
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
-from synapse.api.constants import EventTypes
-from synapse.visibility import filter_events_for_clients_context
+from synapse.event_auth import get_user_power_level
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import get_metrics_for
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.caches.descriptors import cached
+from synapse.util.async import Linearizer
+from synapse.state import POWER_KEY
+
+from collections import namedtuple
 
 
 logger = logging.getLogger(__name__)
 
 
-@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
-    rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event, context
-    )
-
-    # if this event is an invite event, we may need to run rules for the user
-    # who's been invited, otherwise they won't get told they've been invited
-    if event.type == 'm.room.member' and event.content['membership'] == 'invite':
-        invited_user = event.state_key
-        if invited_user and hs.is_mine_id(invited_user):
-            has_pusher = yield store.user_has_pusher(invited_user)
-            if has_pusher:
-                rules_by_user = dict(rules_by_user)
-                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
-                    invited_user
-                )
+rules_by_room = {}
 
-    defer.returnValue(BulkPushRuleEvaluator(
-        event.room_id, rules_by_user, store
-    ))
+push_metrics = get_metrics_for(__name__)
 
+push_rules_invalidation_counter = push_metrics.register_counter(
+    "push_rules_invalidation_counter"
+)
+push_rules_state_size_counter = push_metrics.register_counter(
+    "push_rules_state_size_counter"
+)
 
-class BulkPushRuleEvaluator:
-    """
-    Runs push rules for all users in a room.
-    This is faster than running PushRuleEvaluator for each user because it
-    fetches all the rules for all the users in one (batched) db query
-    rather than doing multiple queries per-user. It currently uses
-    the same logic to run the actual rules, but could be optimised further
-    (see https://matrix.org/jira/browse/SYN-562)
+# Measures whether we use the fast path of using state deltas, or if we have to
+# recalculate from scratch
+push_rules_delta_state_cache_metric = cache_metrics.register_cache(
+    "cache",
+    size_callback=lambda: 0,  # Meaningless size, as this isn't a cache that stores values
+    cache_name="push_rules_delta_state_cache_metric",
+)
+
+
+class BulkPushRuleEvaluator(object):
+    """Calculates the outcome of push rules for an event for all users in the
+    room at once.
     """
-    def __init__(self, room_id, rules_by_user, store):
-        self.room_id = room_id
-        self.rules_by_user = rules_by_user
-        self.store = store
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+        self.room_push_rule_cache_metrics = cache_metrics.register_cache(
+            "cache",
+            size_callback=lambda: 0,  # There's not good value for this
+            cache_name="room_push_rule_cache",
+        )
 
     @defer.inlineCallbacks
-    def action_for_event_by_user(self, event, context):
-        actions_by_user = {}
+    def _get_rules_for_event(self, event, context):
+        """This gets the rules for all users in the room at the time of the event,
+        as well as the push rules for the invitee if the event is an invite.
+
+        Returns:
+            dict of user_id -> push_rules
+        """
+        room_id = event.room_id
+        rules_for_room = self._get_rules_for_room(room_id)
+
+        rules_by_user = yield rules_for_room.get_rules(event, context)
+
+        # if this event is an invite event, we may need to run rules for the user
+        # who's been invited, otherwise they won't get told they've been invited
+        if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+            invited = event.state_key
+            if invited and self.hs.is_mine_id(invited):
+                has_pusher = yield self.store.user_has_pusher(invited)
+                if has_pusher:
+                    rules_by_user = dict(rules_by_user)
+                    rules_by_user[invited] = yield self.store.get_push_rules_for_user(
+                        invited
+                    )
 
-        # None of these users can be peeking since this list of users comes
-        # from the set of users in the room, so we know for sure they're all
-        # actually in the room.
-        user_tuples = [
-            (u, False) for u in self.rules_by_user.keys()
-        ]
+        defer.returnValue(rules_by_user)
 
-        filtered_by_user = yield filter_events_for_clients_context(
-            self.store, user_tuples, [event], {event.event_id: context}
+    @cached()
+    def _get_rules_for_room(self, room_id):
+        """Get the current RulesForRoom object for the given room id
+
+        Returns:
+            RulesForRoom
+        """
+        # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+        # before any lookup methods get called on it as otherwise there may be
+        # a race if invalidate_all gets called (which assumes its in the cache)
+        return RulesForRoom(
+            self.hs, room_id, self._get_rules_for_room.cache,
+            self.room_push_rule_cache_metrics,
         )
 
+    @defer.inlineCallbacks
+    def _get_power_levels_and_sender_level(self, event, context):
+        pl_event_id = context.prev_state_ids.get(POWER_KEY)
+        if pl_event_id:
+            # fastpath: if there's a power level event, that's all we need, and
+            # not having a power level event is an extreme edge case
+            pl_event = yield self.store.get_event(pl_event_id)
+            auth_events = {POWER_KEY: pl_event}
+        else:
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.prev_state_ids, for_verification=False,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.itervalues()
+            }
+
+        sender_level = get_user_power_level(event.sender, auth_events)
+
+        pl_event = auth_events.get(POWER_KEY)
+
+        defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+
+    @defer.inlineCallbacks
+    def action_for_event_by_user(self, event, context):
+        """Given an event and context, evaluate the push rules and insert the
+        results into the event_push_actions_staging table.
+
+        Returns:
+            Deferred
+        """
+        rules_by_user = yield self._get_rules_for_event(event, context)
+        actions_by_user = {}
+
         room_members = yield self.store.get_joined_users_from_context(
             event, context
         )
 
-        evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
+        (power_levels, sender_power_level) = (
+            yield self._get_power_levels_and_sender_level(event, context)
+        )
+
+        evaluator = PushRuleEvaluatorForEvent(
+            event, len(room_members), sender_power_level, power_levels,
+        )
 
         condition_cache = {}
 
-        for uid, rules in self.rules_by_user.items():
-            display_name = room_members.get(uid, {}).get("display_name", None)
+        for uid, rules in rules_by_user.iteritems():
+            if event.sender == uid:
+                continue
+
+            if not event.is_state():
+                is_ignored = yield self.store.is_ignored_by(event.sender, uid)
+                if is_ignored:
+                    continue
+
+            display_name = None
+            profile_info = room_members.get(uid)
+            if profile_info:
+                display_name = profile_info.display_name
+
             if not display_name:
                 # Handle the case where we are pushing a membership event to
                 # that user, as they might not be already joined.
                 if event.type == EventTypes.Member and event.state_key == uid:
                     display_name = event.content.get("displayname", None)
 
-            filtered = filtered_by_user[uid]
-            if len(filtered) == 0:
-                continue
-
-            if filtered[0].sender == uid:
-                continue
-
             for rule in rules:
                 if 'enabled' in rule and not rule['enabled']:
                     continue
@@ -111,9 +190,16 @@ class BulkPushRuleEvaluator:
                 if matches:
                     actions = [x for x in rule['actions'] if x != 'dont_notify']
                     if actions and 'notify' in actions:
+                        # Push rules say we should notify the user of this event
                         actions_by_user[uid] = actions
                     break
-        defer.returnValue(actions_by_user)
+
+        # Mark in the DB staging area the push actions for users who should be
+        # notified for this event. (This will then get handled when we persist
+        # the event)
+        yield self.store.add_push_actions_to_staging(
+            event.event_id, actions_by_user,
+        )
 
 
 def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -134,3 +220,264 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
             return False
 
     return True
+
+
+class RulesForRoom(object):
+    """Caches push rules for users in a room.
+
+    This efficiently handles users joining/leaving the room by not invalidating
+    the entire cache for the room.
+    """
+
+    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+        """
+        Args:
+            hs (HomeServer)
+            room_id (str)
+            rules_for_room_cache(Cache): The cache object that caches these
+                RoomsForUser objects.
+            room_push_rule_cache_metrics (CacheMetric)
+        """
+        self.room_id = room_id
+        self.is_mine_id = hs.is_mine_id
+        self.store = hs.get_datastore()
+        self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
+
+        self.linearizer = Linearizer(name="rules_for_room")
+
+        self.member_map = {}  # event_id -> (user_id, state)
+        self.rules_by_user = {}  # user_id -> rules
+
+        # The last state group we updated the caches for. If the state_group of
+        # a new event comes along, we know that we can just return the cached
+        # result.
+        # On invalidation of the rules themselves (if the user changes them),
+        # we invalidate everything and set state_group to `object()`
+        self.state_group = object()
+
+        # A sequence number to keep track of when we're allowed to update the
+        # cache. We bump the sequence number when we invalidate the cache. If
+        # the sequence number changes while we're calculating stuff we should
+        # not update the cache with it.
+        self.sequence = 0
+
+        # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+        # owned by AS's, or remote users, etc. (I.e. users we will never need to
+        # calculate push for)
+        # These never need to be invalidated as we will never set up push for
+        # them.
+        self.uninteresting_user_set = set()
+
+        # We need to be clever on the invalidating caches callbacks, as
+        # otherwise the invalidation callback holds a reference to the object,
+        # potentially causing it to leak.
+        # To get around this we pass a function that on invalidations looks ups
+        # the RoomsForUser entry in the cache, rather than keeping a reference
+        # to self around in the callback.
+        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
+
+    @defer.inlineCallbacks
+    def get_rules(self, event, context):
+        """Given an event context return the rules for all users who are
+        currently in the room.
+        """
+        state_group = context.state_group
+
+        if state_group and self.state_group == state_group:
+            logger.debug("Using cached rules for %r", self.room_id)
+            self.room_push_rule_cache_metrics.inc_hits()
+            defer.returnValue(self.rules_by_user)
+
+        with (yield self.linearizer.queue(())):
+            if state_group and self.state_group == state_group:
+                logger.debug("Using cached rules for %r", self.room_id)
+                self.room_push_rule_cache_metrics.inc_hits()
+                defer.returnValue(self.rules_by_user)
+
+            self.room_push_rule_cache_metrics.inc_misses()
+
+            ret_rules_by_user = {}
+            missing_member_event_ids = {}
+            if state_group and self.state_group == context.prev_group:
+                # If we have a simple delta then we can reuse most of the previous
+                # results.
+                ret_rules_by_user = self.rules_by_user
+                current_state_ids = context.delta_ids
+
+                push_rules_delta_state_cache_metric.inc_hits()
+            else:
+                current_state_ids = context.current_state_ids
+                push_rules_delta_state_cache_metric.inc_misses()
+
+            push_rules_state_size_counter.inc_by(len(current_state_ids))
+
+            logger.debug(
+                "Looking for member changes in %r %r", state_group, current_state_ids
+            )
+
+            # Loop through to see which member events we've seen and have rules
+            # for and which we need to fetch
+            for key in current_state_ids:
+                typ, user_id = key
+                if typ != EventTypes.Member:
+                    continue
+
+                if user_id in self.uninteresting_user_set:
+                    continue
+
+                if not self.is_mine_id(user_id):
+                    self.uninteresting_user_set.add(user_id)
+                    continue
+
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    self.uninteresting_user_set.add(user_id)
+                    continue
+
+                event_id = current_state_ids[key]
+
+                res = self.member_map.get(event_id, None)
+                if res:
+                    user_id, state = res
+                    if state == Membership.JOIN:
+                        rules = self.rules_by_user.get(user_id, None)
+                        if rules:
+                            ret_rules_by_user[user_id] = rules
+                    continue
+
+                # If a user has left a room we remove their push rule. If they
+                # joined then we readd it later in _update_rules_with_member_event_ids
+                ret_rules_by_user.pop(user_id, None)
+                missing_member_event_ids[user_id] = event_id
+
+            if missing_member_event_ids:
+                # If we have some memebr events we haven't seen, look them up
+                # and fetch push rules for them if appropriate.
+                logger.debug("Found new member events %r", missing_member_event_ids)
+                yield self._update_rules_with_member_event_ids(
+                    ret_rules_by_user, missing_member_event_ids, state_group, event
+                )
+            else:
+                # The push rules didn't change but lets update the cache anyway
+                self.update_cache(
+                    self.sequence,
+                    members={},  # There were no membership changes
+                    rules_by_user=ret_rules_by_user,
+                    state_group=state_group
+                )
+
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                "Returning push rules for %r %r",
+                self.room_id, ret_rules_by_user.keys(),
+            )
+        defer.returnValue(ret_rules_by_user)
+
+    @defer.inlineCallbacks
+    def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
+                                            state_group, event):
+        """Update the partially filled rules_by_user dict by fetching rules for
+        any newly joined users in the `member_event_ids` list.
+
+        Args:
+            ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+                updated with any new rules.
+            member_event_ids (list): List of event ids for membership events that
+                have happened since the last time we filled rules_by_user
+            state_group: The state group we are currently computing push rules
+                for. Used when updating the cache.
+        """
+        sequence = self.sequence
+
+        rows = yield self.store._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids.values(),
+            retcols=('user_id', 'membership', 'event_id'),
+            keyvalues={},
+            batch_size=500,
+            desc="_get_rules_for_member_event_ids",
+        )
+
+        members = {
+            row["event_id"]: (row["user_id"], row["membership"])
+            for row in rows
+        }
+
+        # If the event is a join event then it will be in current state evnts
+        # map but not in the DB, so we have to explicitly insert it.
+        if event.type == EventTypes.Member:
+            for event_id in member_event_ids.itervalues():
+                if event_id == event.event_id:
+                    members[event_id] = (event.state_key, event.membership)
+
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug("Found members %r: %r", self.room_id, members.values())
+
+        interested_in_user_ids = set(
+            user_id for user_id, membership in members.itervalues()
+            if membership == Membership.JOIN
+        )
+
+        logger.debug("Joined: %r", interested_in_user_ids)
+
+        if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+            interested_in_user_ids,
+            on_invalidate=self.invalidate_all_cb,
+        )
+
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+        )
+
+        logger.debug("With pushers: %r", user_ids)
+
+        users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+            self.room_id, on_invalidate=self.invalidate_all_cb,
+        )
+
+        logger.debug("With receipts: %r", users_with_receipts)
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in interested_in_user_ids:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.store.bulk_get_push_rules(
+            user_ids, on_invalidate=self.invalidate_all_cb,
+        )
+
+        ret_rules_by_user.update(
+            item for item in rules_by_user.iteritems() if item[0] is not None
+        )
+
+        self.update_cache(sequence, members, ret_rules_by_user, state_group)
+
+    def invalidate_all(self):
+        # Note: Don't hand this function directly to an invalidation callback
+        # as it keeps a reference to self and will stop this instance from being
+        # GC'd if it gets dropped from the rules_to_user cache. Instead use
+        # `self.invalidate_all_cb`
+        logger.debug("Invalidating RulesForRoom for %r", self.room_id)
+        self.sequence += 1
+        self.state_group = object()
+        self.member_map = {}
+        self.rules_by_user = {}
+        push_rules_invalidation_counter.inc()
+
+    def update_cache(self, sequence, members, rules_by_user, state_group):
+        if sequence == self.sequence:
+            self.member_map.update(members)
+            self.rules_by_user = rules_by_user
+            self.state_group = state_group
+
+
+class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
+    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
+    # which namedtuple does for us (i.e. two _CacheContext are the same if
+    # their caches and keys match). This is important in particular to
+    # dedupe when we add callbacks to lru cache nodes, otherwise the number
+    # of callbacks would grow.
+    def __call__(self):
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
+        if rules:
+            rules.invalidate_all()
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 2eb325c7c7..ba7286cb72 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -21,7 +21,6 @@ import logging
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import LoggingContext
 
-from mailer import Mailer
 
 logger = logging.getLogger(__name__)
 
@@ -56,8 +55,10 @@ class EmailPusher(object):
     This shares quite a bit of code with httpusher: it would be good to
     factor out the common parts
     """
-    def __init__(self, hs, pusherdict):
+    def __init__(self, hs, pusherdict, mailer):
         self.hs = hs
+        self.mailer = mailer
+
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
         self.pusher_id = pusherdict['id']
@@ -73,23 +74,16 @@ class EmailPusher(object):
 
         self.processing = False
 
-        if self.hs.config.email_enable_notifs:
-            if 'data' in pusherdict and 'brand' in pusherdict['data']:
-                app_name = pusherdict['data']['brand']
-            else:
-                app_name = self.hs.config.email_app_name
-
-            self.mailer = Mailer(self.hs, app_name)
-        else:
-            self.mailer = None
-
     @defer.inlineCallbacks
     def on_started(self):
         if self.mailer is not None:
-            self.throttle_params = yield self.store.get_throttle_params_by_room(
-                self.pusher_id
-            )
-            yield self._process()
+            try:
+                self.throttle_params = yield self.store.get_throttle_params_by_room(
+                    self.pusher_id
+                )
+                yield self._process()
+            except Exception:
+                logger.exception("Error starting email pusher")
 
     def on_stop(self):
         if self.timed_call:
@@ -130,7 +124,7 @@ class EmailPusher(object):
                         starting_max_ordering = self.max_stream_ordering
                         try:
                             yield self._unsafe_process()
-                        except:
+                        except Exception:
                             logger.exception("Exception processing notifs")
                         if self.max_stream_ordering == starting_max_ordering:
                             break
@@ -218,7 +212,8 @@ class EmailPusher(object):
         )
 
     def seconds_until(self, ts_msec):
-        return (ts_msec - self.clock.time_msec()) / 1000
+        secs = (ts_msec - self.clock.time_msec()) / 1000
+        return max(secs, 0)
 
     def get_room_throttle_ms(self, room_id):
         if room_id in self.throttle_params:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index c0f8176e3d..b077e1a446 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,21 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from synapse.push import PusherConfigException
+import logging
 
 from twisted.internet import defer, reactor
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
-import logging
-import push_rule_evaluator
-import push_tools
-
+from . import push_rule_evaluator
+from . import push_tools
+import synapse
+from synapse.push import PusherConfigException
 from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
 
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+http_push_processed_counter = metrics.register_counter(
+    "http_pushes_processed",
+)
+
+http_push_failed_counter = metrics.register_counter(
+    "http_pushes_failed",
+)
+
 
 class HttpPusher(object):
     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes
@@ -84,7 +94,10 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def on_started(self):
-        yield self._process()
+        try:
+            yield self._process()
+        except Exception:
+            logger.exception("Error starting http pusher")
 
     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -131,7 +144,7 @@ class HttpPusher(object):
                         starting_max_ordering = self.max_stream_ordering
                         try:
                             yield self._unsafe_process()
-                        except:
+                        except Exception:
                             logger.exception("Exception processing notifs")
                         if self.max_stream_ordering == starting_max_ordering:
                             break
@@ -151,9 +164,16 @@ class HttpPusher(object):
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
 
+        logger.info(
+            "Processing %i unprocessed push actions for %s starting at "
+            "stream_ordering %s",
+            len(unprocessed), self.name, self.last_stream_ordering,
+        )
+
         for push_action in unprocessed:
             processed = yield self._process_one(push_action)
             if processed:
+                http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action['stream_ordering']
                 yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -168,6 +188,7 @@ class HttpPusher(object):
                         self.failing_since
                     )
             else:
+                http_push_failed_counter.inc()
                 if not self.failing_since:
                     self.failing_since = self.clock.time_msec()
                     yield self.store.update_pusher_failing_since(
@@ -244,6 +265,26 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _build_notification_dict(self, event, tweaks, badge):
+        if self.data.get('format') == 'event_id_only':
+            d = {
+                'notification': {
+                    'event_id': event.event_id,
+                    'room_id': event.room_id,
+                    'counts': {
+                        'unread': badge,
+                    },
+                    'devices': [
+                        {
+                            'app_id': self.app_id,
+                            'pushkey': self.pushkey,
+                            'pushkey_ts': long(self.pushkey_ts / 1000),
+                            'data': self.data_minus_url,
+                        }
+                    ]
+                }
+            }
+            defer.returnValue(d)
+
         ctx = yield push_tools.get_context_for_event(
             self.store, self.state_handler, event, self.user_id
         )
@@ -275,7 +316,7 @@ class HttpPusher(object):
         if event.type == 'm.room.member':
             d['notification']['membership'] = event.content['membership']
             d['notification']['user_is_target'] = event.state_key == self.user_id
-        if 'content' in event:
+        if self.hs.config.push_include_content and 'content' in event:
             d['notification']['content'] = event.content
 
         # We no longer send aliases separately, instead, we send the human
@@ -294,8 +335,11 @@ class HttpPusher(object):
             defer.returnValue([])
         try:
             resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
-        except:
-            logger.warn("Failed to push %s ", self.url)
+        except Exception:
+            logger.warn(
+                "Failed to push event %s to %s",
+                event.event_id, self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
@@ -304,7 +348,7 @@ class HttpPusher(object):
 
     @defer.inlineCallbacks
     def _send_badge(self, badge):
-        logger.info("Sending updated badge count %d to %r", badge, self.user_id)
+        logger.info("Sending updated badge count %d to %s", badge, self.name)
         d = {
             'notification': {
                 'id': '',
@@ -325,8 +369,11 @@ class HttpPusher(object):
         }
         try:
             resp = yield self.http_client.post_json_get_json(self.url, d)
-        except:
-            logger.exception("Failed to push %s ", self.url)
+        except Exception:
+            logger.warn(
+                "Failed to send badge count to %s",
+                self.name, exc_info=True,
+            )
             defer.returnValue(False)
         rejected = []
         if 'rejected' in resp:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 62d794f22b..b5cd9b426a 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -78,23 +78,17 @@ ALLOWED_ATTRS = {
 
 
 class Mailer(object):
-    def __init__(self, hs, app_name):
+    def __init__(self, hs, app_name, notif_template_html, notif_template_text):
         self.hs = hs
+        self.notif_template_html = notif_template_html
+        self.notif_template_text = notif_template_text
+
         self.store = self.hs.get_datastore()
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
-        loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
         self.app_name = app_name
+
         logger.info("Created Mailer for app_name %s" % app_name)
-        env = jinja2.Environment(loader=loader)
-        env.filters["format_ts"] = format_ts_filter
-        env.filters["mxc_to_http"] = self.mxc_to_http_filter
-        self.notif_template_html = env.get_template(
-            self.hs.config.email_notif_template_html
-        )
-        self.notif_template_text = env.get_template(
-            self.hs.config.email_notif_template_text
-        )
 
     @defer.inlineCallbacks
     def send_notification_mail(self, app_id, user_id, email_address,
@@ -139,7 +133,7 @@ class Mailer(object):
 
         @defer.inlineCallbacks
         def _fetch_room_state(room_id):
-            room_state = yield self.state_handler.get_current_state_ids(room_id)
+            room_state = yield self.store.get_current_state_ids(room_id)
             state_by_room[room_id] = room_state
 
         # Run at most 3 of these at once: sync does 10 at a time but email
@@ -200,7 +194,11 @@ class Mailer(object):
         yield sendmail(
             self.hs.config.email_smtp_host,
             raw_from, raw_to, multipart_msg.as_string(),
-            port=self.hs.config.email_smtp_port
+            port=self.hs.config.email_smtp_port,
+            requireAuthentication=self.hs.config.email_smtp_user is not None,
+            username=self.hs.config.email_smtp_user,
+            password=self.hs.config.email_smtp_pass,
+            requireTransportSecurity=self.hs.config.require_transport_security
         )
 
     @defer.inlineCallbacks
@@ -477,28 +475,6 @@ class Mailer(object):
             urllib.urlencode(params),
         )
 
-    def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
-        if value[0:6] != "mxc://":
-            return ""
-
-        serverAndMediaId = value[6:]
-        fragment = None
-        if '#' in serverAndMediaId:
-            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
-            fragment = "#" + fragment
-
-        params = {
-            "width": width,
-            "height": height,
-            "method": resize_method,
-        }
-        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
-            self.hs.config.public_baseurl,
-            serverAndMediaId,
-            urllib.urlencode(params),
-            fragment or "",
-        )
-
 
 def safe_markup(raw_html):
     return jinja2.Markup(bleach.linkify(bleach.clean(
@@ -539,3 +515,52 @@ def string_ordinal_total(s):
 
 def format_ts_filter(value, format):
     return time.strftime(format, time.localtime(value / 1000))
+
+
+def load_jinja2_templates(config):
+    """Load the jinja2 email templates from disk
+
+    Returns:
+        (notif_template_html, notif_template_text)
+    """
+    logger.info("loading jinja2")
+
+    loader = jinja2.FileSystemLoader(config.email_template_dir)
+    env = jinja2.Environment(loader=loader)
+    env.filters["format_ts"] = format_ts_filter
+    env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
+
+    notif_template_html = env.get_template(
+        config.email_notif_template_html
+    )
+    notif_template_text = env.get_template(
+        config.email_notif_template_text
+    )
+
+    return notif_template_html, notif_template_text
+
+
+def _create_mxc_to_http_filter(config):
+    def mxc_to_http_filter(value, width, height, resize_method="crop"):
+        if value[0:6] != "mxc://":
+            return ""
+
+        serverAndMediaId = value[6:]
+        fragment = None
+        if '#' in serverAndMediaId:
+            (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+            fragment = "#" + fragment
+
+        params = {
+            "width": width,
+            "height": height,
+            "method": resize_method,
+        }
+        return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
+            config.public_baseurl,
+            serverAndMediaId,
+            urllib.urlencode(params),
+            fragment or "",
+        )
+
+    return mxc_to_http_filter
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 4db76f18bd..3601f2d365 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +18,7 @@ import logging
 import re
 
 from synapse.types import UserID
+from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -28,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
 def _room_member_count(ev, condition, room_member_count):
+    return _test_ineq_condition(condition, room_member_count)
+
+
+def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+    notif_level_key = condition.get('key')
+    if notif_level_key is None:
+        return False
+
+    notif_levels = power_levels.get('notifications', {})
+    room_notif_level = notif_levels.get(notif_level_key, 50)
+
+    return sender_power_level >= room_notif_level
+
+
+def _test_ineq_condition(condition, number):
     if 'is' not in condition:
         return False
     m = INEQUALITY_EXPR.match(condition['is'])
@@ -40,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
     rhs = int(rhs)
 
     if ineq == '' or ineq == '==':
-        return room_member_count == rhs
+        return number == rhs
     elif ineq == '<':
-        return room_member_count < rhs
+        return number < rhs
     elif ineq == '>':
-        return room_member_count > rhs
+        return number > rhs
     elif ineq == '>=':
-        return room_member_count >= rhs
+        return number >= rhs
     elif ineq == '<=':
-        return room_member_count <= rhs
+        return number <= rhs
     else:
         return False
 
@@ -64,9 +81,11 @@ def tweaks_for_actions(actions):
 
 
 class PushRuleEvaluatorForEvent(object):
-    def __init__(self, event, room_member_count):
+    def __init__(self, event, room_member_count, sender_power_level, power_levels):
         self._event = event
         self._room_member_count = room_member_count
+        self._sender_power_level = sender_power_level
+        self._power_levels = power_levels
 
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
@@ -80,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
             return _room_member_count(
                 self._event, condition, self._room_member_count
             )
+        elif condition['kind'] == 'sender_notification_permission':
+            return _sender_notification_permission(
+                self._event, condition, self._sender_power_level, self._power_levels,
+            )
         else:
             return True
 
@@ -125,6 +148,11 @@ class PushRuleEvaluatorForEvent(object):
         return self._value_cache.get(dotted_key, None)
 
 
+# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
+register_cache("regex_push_cache", regex_cache)
+
+
 def _glob_matches(glob, value, word_boundary=False):
     """Tests if value matches glob.
 
@@ -137,47 +165,78 @@ def _glob_matches(glob, value, word_boundary=False):
     Returns:
         bool
     """
+
     try:
-        if IS_GLOB.search(glob):
-            r = re.escape(glob)
-
-            r = r.replace(r'\*', '.*?')
-            r = r.replace(r'\?', '.')
-
-            # handle [abc], [a-z] and [!a-z] style ranges.
-            r = GLOB_REGEX.sub(
-                lambda x: (
-                    '[%s%s]' % (
-                        x.group(1) and '^' or '',
-                        x.group(2).replace(r'\\\-', '-')
-                    )
-                ),
-                r,
-            )
-            if word_boundary:
-                r = r"\b%s\b" % (r,)
-                r = _compile_regex(r)
-
-                return r.search(value)
-            else:
-                r = r + "$"
-                r = _compile_regex(r)
-
-                return r.match(value)
-        elif word_boundary:
-            r = re.escape(glob)
-            r = r"\b%s\b" % (r,)
-            r = _compile_regex(r)
-
-            return r.search(value)
-        else:
-            return value.lower() == glob.lower()
+        r = regex_cache.get((glob, word_boundary), None)
+        if not r:
+            r = _glob_to_re(glob, word_boundary)
+            regex_cache[(glob, word_boundary)] = r
+        return r.search(value)
     except re.error:
         logger.warn("Failed to parse glob to regex: %r", glob)
         return False
 
 
-def _flatten_dict(d, prefix=[], result={}):
+def _glob_to_re(glob, word_boundary):
+    """Generates regex for a given glob.
+
+    Args:
+        glob (string)
+        word_boundary (bool): Whether to match against word boundaries or entire
+            string. Defaults to False.
+
+    Returns:
+        regex object
+    """
+    if IS_GLOB.search(glob):
+        r = re.escape(glob)
+
+        r = r.replace(r'\*', '.*?')
+        r = r.replace(r'\?', '.')
+
+        # handle [abc], [a-z] and [!a-z] style ranges.
+        r = GLOB_REGEX.sub(
+            lambda x: (
+                '[%s%s]' % (
+                    x.group(1) and '^' or '',
+                    x.group(2).replace(r'\\\-', '-')
+                )
+            ),
+            r,
+        )
+        if word_boundary:
+            r = _re_word_boundary(r)
+
+            return re.compile(r, flags=re.IGNORECASE)
+        else:
+            r = "^" + r + "$"
+
+            return re.compile(r, flags=re.IGNORECASE)
+    elif word_boundary:
+        r = re.escape(glob)
+        r = _re_word_boundary(r)
+
+        return re.compile(r, flags=re.IGNORECASE)
+    else:
+        r = "^" + re.escape(glob) + "$"
+        return re.compile(r, flags=re.IGNORECASE)
+
+
+def _re_word_boundary(r):
+    """
+    Adds word boundary characters to the start and end of an
+    expression to require that the match occur as a whole word,
+    but do so respecting the fact that strings starting or ending
+    with non-word characters will change word boundaries.
+    """
+    # we can't use \b as it chokes on unicode. however \W seems to be okay
+    # as shorthand for [^0-9A-Za-z_].
+    return r"(^|\W)%s(\W|$)" % (r,)
+
+
+def _flatten_dict(d, prefix=[], result=None):
+    if result is None:
+        result = {}
     for key, value in d.items():
         if isinstance(value, basestring):
             result[".".join(prefix + [key])] = value.lower()
@@ -185,16 +244,3 @@ def _flatten_dict(d, prefix=[], result={}):
             _flatten_dict(value, prefix=(prefix + [key]), result=result)
 
     return result
-
-
-regex_cache = LruCache(5000)
-
-
-def _compile_regex(regex_str):
-    r = regex_cache.get(regex_str, None)
-    if r:
-        return r
-
-    r = re.compile(regex_str, flags=re.IGNORECASE)
-    regex_cache[regex_str] = r
-    return r
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a27476bbad..6835f54e97 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -17,15 +17,12 @@ from twisted.internet import defer
 from synapse.push.presentable_names import (
     calculate_room_name, name_from_member_event
 )
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 
 @defer.inlineCallbacks
 def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites = yield store.get_invited_rooms_for_user(user_id)
+    joins = yield store.get_rooms_for_user(user_id)
 
     my_receipts_by_room = yield store.get_receipts_for_user(
         user_id, "m.read",
@@ -33,13 +30,13 @@ def get_badge_count(store, user_id):
 
     badge = len(invites)
 
-    for r in joins:
-        if r.room_id in my_receipts_by_room:
-            last_unread_event_id = my_receipts_by_room[r.room_id]
+    for room_id in joins:
+        if room_id in my_receipts_by_room:
+            last_unread_event_id = my_receipts_by_room[room_id]
 
             notifs = yield (
                 store.get_unread_event_push_actions_by_room_for_user(
-                    r.room_id, user_id, last_unread_event_id
+                    room_id, user_id, last_unread_event_id
                 )
             )
             # return one badge count per conversation, as count per
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index de9c33b936..5aa6667e91 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
 
 import logging
 logger = logging.getLogger(__name__)
@@ -26,22 +26,54 @@ logger = logging.getLogger(__name__)
 # process works fine)
 try:
     from synapse.push.emailpusher import EmailPusher
-except:
+    from synapse.push.mailer import Mailer, load_jinja2_templates
+except Exception:
     pass
 
 
-def create_pusher(hs, pusherdict):
-    logger.info("trying to create_pusher for %r", pusherdict)
+class PusherFactory(object):
+    def __init__(self, hs):
+        self.hs = hs
 
-    PUSHER_TYPES = {
-        "http": HttpPusher,
-    }
+        self.pusher_types = {
+            "http": HttpPusher,
+        }
 
-    logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
-    if hs.config.email_enable_notifs:
-        PUSHER_TYPES["email"] = EmailPusher
-        logger.info("defined email pusher type")
+        logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
+        if hs.config.email_enable_notifs:
+            self.mailers = {}  # app_name -> Mailer
 
-    if pusherdict['kind'] in PUSHER_TYPES:
-        logger.info("found pusher")
-        return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
+            templates = load_jinja2_templates(hs.config)
+            self.notif_template_html, self.notif_template_text = templates
+
+            self.pusher_types["email"] = self._create_email_pusher
+
+            logger.info("defined email pusher type")
+
+    def create_pusher(self, pusherdict):
+        logger.info("trying to create_pusher for %r", pusherdict)
+
+        if pusherdict['kind'] in self.pusher_types:
+            logger.info("found pusher")
+            return self.pusher_types[pusherdict['kind']](self.hs, pusherdict)
+
+    def _create_email_pusher(self, _hs, pusherdict):
+        app_name = self._app_name_from_pusherdict(pusherdict)
+        mailer = self.mailers.get(app_name)
+        if not mailer:
+            mailer = Mailer(
+                hs=self.hs,
+                app_name=app_name,
+                notif_template_html=self.notif_template_html,
+                notif_template_text=self.notif_template_text,
+            )
+            self.mailers[app_name] = mailer
+        return EmailPusher(self.hs, pusherdict, mailer)
+
+    def _app_name_from_pusherdict(self, pusherdict):
+        if 'data' in pusherdict and 'brand' in pusherdict['data']:
+            app_name = pusherdict['data']['brand']
+        else:
+            app_name = self.hs.config.email_app_name
+
+        return app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3837be523d..750d11ca38 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 from twisted.internet import defer
 
-import pusher
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.push.pusher import PusherFactory
 from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 logger = logging.getLogger(__name__)
 
@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
 class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
+        self.pusher_factory = PusherFactory(_hs)
         self.start_pushers = _hs.config.start_pushers
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
@@ -48,7 +49,7 @@ class PusherPool:
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
         # code path adding pushers.
-        pusher.create_pusher(self.hs, {
+        self.pusher_factory.create_pusher({
             "id": None,
             "user_name": user_id,
             "kind": kind,
@@ -102,19 +103,25 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
-        all = yield self.store.get_all_pushers()
-        logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
-        )
-        for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+    def remove_pushers_by_access_token(self, user_id, access_tokens):
+        """Remove the pushers for a given user corresponding to a set of
+        access_tokens.
+
+        Args:
+            user_id (str): user to remove pushers for
+            access_tokens (Iterable[int]): access token *ids* to remove pushers
+                for
+        """
+        tokens = set(access_tokens)
+        for p in (yield self.store.get_pushers_by_user_id(user_id)):
+            if p['access_token'] in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
-                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(
+                    p['app_id'], p['pushkey'], p['user_name'],
+                )
 
     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_id, max_stream_id):
@@ -130,13 +137,16 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_notifications)(
-                                min_stream_id, max_stream_id
+                            run_in_background(
+                                p.on_new_notifications,
+                                min_stream_id, max_stream_id,
                             )
                         )
 
-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
-        except:
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
+        except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
     @defer.inlineCallbacks
@@ -157,11 +167,16 @@ class PusherPool:
                 if u in self.pushers:
                     for p in self.pushers[u].values():
                         deferreds.append(
-                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+                            run_in_background(
+                                p.on_new_receipts,
+                                min_stream_id, max_stream_id,
+                            )
                         )
 
-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
-        except:
+            yield make_deferred_yieldable(
+                defer.gatherResults(deferreds, consumeErrors=True),
+            )
+        except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
     @defer.inlineCallbacks
@@ -186,8 +201,8 @@ class PusherPool:
         logger.info("Starting %d pushers", len(pushers))
         for pusherdict in pushers:
             try:
-                p = pusher.create_pusher(self.hs, pusherdict)
-            except:
+                p = self.pusher_factory.create_pusher(pusherdict)
+            except Exception:
                 logger.exception("Couldn't start a pusher: caught Exception")
                 continue
             if p:
@@ -200,7 +215,7 @@ class PusherPool:
                 if appid_pushkey in byuser:
                     byuser[appid_pushkey].on_stop()
                 byuser[appid_pushkey] = p
-                preserve_fn(p.on_started)()
+                run_in_background(p.on_started)
 
         logger.info("Started pushers")
 
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 7817b0cd91..216db4d164 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,4 +1,6 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,26 +19,43 @@ from distutils.version import LooseVersion
 
 logger = logging.getLogger(__name__)
 
+# this dict maps from python package name to a list of modules we expect it to
+# provide.
+#
+# the key is a "requirement specifier", as used as a parameter to `pip
+# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+#
+# the value is a sequence of strings; each entry should be the name of the
+# python module, optionally followed by a version assertion which can be either
+# ">=<ver>" or "==<ver>".
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
 REQUIREMENTS = {
+    "jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
     "frozendict>=0.4": ["frozendict"],
     "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
-    "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
+    "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
     "signedjson>=1.0.0": ["signedjson>=1.0.0"],
-    "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
+    "pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
     "service_identity>=1.0.0": ["service_identity>=1.0.0"],
     "Twisted>=16.0.0": ["twisted>=16.0.0"],
-    "pyopenssl>=0.14": ["OpenSSL>=0.14"],
+
+    # We use crypto.get_elliptic_curve which is only supported in >=0.15
+    "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+
     "pyyaml": ["yaml"],
     "pyasn1": ["pyasn1"],
     "daemonize": ["daemonize"],
-    "py-bcrypt": ["bcrypt"],
+    "bcrypt": ["bcrypt>=3.1.0"],
     "pillow": ["PIL"],
     "pydenticon": ["pydenticon"],
-    "ujson": ["ujson"],
     "blist": ["blist"],
-    "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
+    "pysaml2>=3.0.0": ["saml2>=3.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
     "msgpack-python>=0.3.0": ["msgpack"],
+    "phonenumbers>=8.2.0": ["phonenumbers"],
+    "six": ["six"],
 }
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
@@ -55,6 +74,9 @@ CONDITIONAL_REQUIREMENTS = {
     "psutil": {
         "psutil>=2.0.0": ["psutil>=2.0.0"],
     },
+    "affinity": {
+        "affinity": ["affinity"],
+    },
 }
 
 
diff --git a/synapse/replication/expire_cache.py b/synapse/replication/expire_cache.py
deleted file mode 100644
index c05a50d7a6..0000000000
--- a/synapse/replication/expire_cache.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-
-
-class ExpireCacheResource(Resource):
-    """
-    HTTP endpoint for expiring storage caches.
-
-    POST /_synapse/replication/expire_cache HTTP/1.1
-    Content-Type: application/json
-
-    {
-        "invalidate": [
-            {
-                "name": "func_name",
-                "keys": ["key1", "key2"]
-            }
-        ]
-    }
-    """
-
-    def __init__(self, hs):
-        Resource.__init__(self)  # Resource is old-style, so no super()
-
-        self.store = hs.get_datastore()
-        self.version_string = hs.version_string
-        self.clock = hs.get_clock()
-
-    def render_POST(self, request):
-        self._async_render_POST(request)
-        return NOT_DONE_YET
-
-    @request_handler()
-    def _async_render_POST(self, request):
-        content = parse_json_object_from_request(request)
-
-        for row in content["invalidate"]:
-            name = row["name"]
-            keys = tuple(row["keys"])
-
-            getattr(self.store, name).invalidate(keys)
-
-        respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
new file mode 100644
index 0000000000..1d7a607529
--- /dev/null
+++ b/synapse/replication/http/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.server import JsonResource
+from synapse.replication.http import membership, send_event
+
+
+REPLICATION_PREFIX = "/_synapse/replication"
+
+
+class ReplicationRestResource(JsonResource):
+    def __init__(self, hs):
+        JsonResource.__init__(self, hs, canonical_json=False)
+        self.register_servlets(hs)
+
+    def register_servlets(self, hs):
+        send_event.register_servlets(hs, self)
+        membership.register_servlets(hs, self)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
new file mode 100644
index 0000000000..e66c4e881f
--- /dev/null
+++ b/synapse/replication/http/membership.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, MatrixCodeMessageException
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import Requester, UserID
+from synapse.util.distributor import user_left_room, user_joined_room
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def remote_join(client, host, port, requester, remote_room_hosts,
+                room_id, user_id, content):
+    """Ask the master to do a remote join for the given user to the given room
+
+    Args:
+        client (SimpleHttpClient)
+        host (str): host of master
+        port (int): port on master listening for HTTP replication
+        requester (Requester)
+        remote_room_hosts (list[str]): Servers to try and join via
+        room_id (str)
+        user_id (str)
+        content (dict): The event content to use for the join event
+
+    Returns:
+        Deferred
+    """
+    uri = "http://%s:%s/_synapse/replication/remote_join" % (host, port)
+
+    payload = {
+        "requester": requester.serialize(),
+        "remote_room_hosts": remote_room_hosts,
+        "room_id": room_id,
+        "user_id": user_id,
+        "content": content,
+    }
+
+    try:
+        result = yield client.post_json_get_json(uri, payload)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def remote_reject_invite(client, host, port, requester, remote_room_hosts,
+                         room_id, user_id):
+    """Ask master to reject the invite for the user and room.
+
+    Args:
+        client (SimpleHttpClient)
+        host (str): host of master
+        port (int): port on master listening for HTTP replication
+        requester (Requester)
+        remote_room_hosts (list[str]): Servers to try and reject via
+        room_id (str)
+        user_id (str)
+
+    Returns:
+        Deferred
+    """
+    uri = "http://%s:%s/_synapse/replication/remote_reject_invite" % (host, port)
+
+    payload = {
+        "requester": requester.serialize(),
+        "remote_room_hosts": remote_room_hosts,
+        "room_id": room_id,
+        "user_id": user_id,
+    }
+
+    try:
+        result = yield client.post_json_get_json(uri, payload)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def get_or_register_3pid_guest(client, host, port, requester,
+                               medium, address, inviter_user_id):
+    """Ask the master to get/create a guest account for given 3PID.
+
+    Args:
+        client (SimpleHttpClient)
+        host (str): host of master
+        port (int): port on master listening for HTTP replication
+        requester (Requester)
+        medium (str)
+        address (str)
+        inviter_user_id (str): The user ID who is trying to invite the
+            3PID
+
+    Returns:
+        Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
+        3PID guest account.
+    """
+
+    uri = "http://%s:%s/_synapse/replication/get_or_register_3pid_guest" % (host, port)
+
+    payload = {
+        "requester": requester.serialize(),
+        "medium": medium,
+        "address": address,
+        "inviter_user_id": inviter_user_id,
+    }
+
+    try:
+        result = yield client.post_json_get_json(uri, payload)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
+
+
+@defer.inlineCallbacks
+def notify_user_membership_change(client, host, port, user_id, room_id, change):
+    """Notify master that a user has joined or left the room
+
+    Args:
+        client (SimpleHttpClient)
+        host (str): host of master
+        port (int): port on master listening for HTTP replication.
+        user_id (str)
+        room_id (str)
+        change (str): Either "join" or "left"
+
+    Returns:
+        Deferred
+    """
+    assert change in ("joined", "left")
+
+    uri = "http://%s:%s/_synapse/replication/user_%s_room" % (host, port, change)
+
+    payload = {
+        "user_id": user_id,
+        "room_id": room_id,
+    }
+
+    try:
+        result = yield client.post_json_get_json(uri, payload)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
+
+
+class ReplicationRemoteJoinRestServlet(RestServlet):
+    PATTERNS = [re.compile("^/_synapse/replication/remote_join$")]
+
+    def __init__(self, hs):
+        super(ReplicationRemoteJoinRestServlet, self).__init__()
+
+        self.federation_handler = hs.get_handlers().federation_handler
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
+
+        remote_room_hosts = content["remote_room_hosts"]
+        room_id = content["room_id"]
+        user_id = content["user_id"]
+        event_content = content["content"]
+
+        requester = Requester.deserialize(self.store, content["requester"])
+
+        if requester.user:
+            request.authenticated_entity = requester.user.to_string()
+
+        logger.info(
+            "remote_join: %s into room: %s",
+            user_id, room_id,
+        )
+
+        yield self.federation_handler.do_invite_join(
+            remote_room_hosts,
+            room_id,
+            user_id,
+            event_content,
+        )
+
+        defer.returnValue((200, {}))
+
+
+class ReplicationRemoteRejectInviteRestServlet(RestServlet):
+    PATTERNS = [re.compile("^/_synapse/replication/remote_reject_invite$")]
+
+    def __init__(self, hs):
+        super(ReplicationRemoteRejectInviteRestServlet, self).__init__()
+
+        self.federation_handler = hs.get_handlers().federation_handler
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
+
+        remote_room_hosts = content["remote_room_hosts"]
+        room_id = content["room_id"]
+        user_id = content["user_id"]
+
+        requester = Requester.deserialize(self.store, content["requester"])
+
+        if requester.user:
+            request.authenticated_entity = requester.user.to_string()
+
+        logger.info(
+            "remote_reject_invite: %s out of room: %s",
+            user_id, room_id,
+        )
+
+        try:
+            event = yield self.federation_handler.do_remotely_reject_invite(
+                remote_room_hosts,
+                room_id,
+                user_id,
+            )
+            ret = event.get_pdu_json()
+        except Exception as e:
+            # if we were unable to reject the exception, just mark
+            # it as rejected on our end and plough ahead.
+            #
+            # The 'except' clause is very broad, but we need to
+            # capture everything from DNS failures upwards
+            #
+            logger.warn("Failed to reject invite: %s", e)
+
+            yield self.store.locally_reject_invite(
+                user_id, room_id
+            )
+            ret = {}
+
+        defer.returnValue((200, ret))
+
+
+class ReplicationRegister3PIDGuestRestServlet(RestServlet):
+    PATTERNS = [re.compile("^/_synapse/replication/get_or_register_3pid_guest$")]
+
+    def __init__(self, hs):
+        super(ReplicationRegister3PIDGuestRestServlet, self).__init__()
+
+        self.registeration_handler = hs.get_handlers().registration_handler
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
+
+        medium = content["medium"]
+        address = content["address"]
+        inviter_user_id = content["inviter_user_id"]
+
+        requester = Requester.deserialize(self.store, content["requester"])
+
+        if requester.user:
+            request.authenticated_entity = requester.user.to_string()
+
+        logger.info("get_or_register_3pid_guest: %r", content)
+
+        ret = yield self.registeration_handler.get_or_register_3pid_guest(
+            medium, address, inviter_user_id,
+        )
+
+        defer.returnValue((200, ret))
+
+
+class ReplicationUserJoinedLeftRoomRestServlet(RestServlet):
+    PATTERNS = [re.compile("^/_synapse/replication/user_(?P<change>joined|left)_room$")]
+
+    def __init__(self, hs):
+        super(ReplicationUserJoinedLeftRoomRestServlet, self).__init__()
+
+        self.registeration_handler = hs.get_handlers().registration_handler
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self.distributor = hs.get_distributor()
+
+    def on_POST(self, request, change):
+        content = parse_json_object_from_request(request)
+
+        user_id = content["user_id"]
+        room_id = content["room_id"]
+
+        logger.info("user membership change: %s in %s", user_id, room_id)
+
+        user = UserID.from_string(user_id)
+
+        if change == "joined":
+            user_joined_room(self.distributor, user, room_id)
+        elif change == "left":
+            user_left_room(self.distributor, user, room_id)
+        else:
+            raise Exception("Unrecognized change: %r", change)
+
+        return (200, {})
+
+
+def register_servlets(hs, http_server):
+    ReplicationRemoteJoinRestServlet(hs).register(http_server)
+    ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
+    ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
+    ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
new file mode 100644
index 0000000000..a9baa2c1c3
--- /dev/null
+++ b/synapse/replication/http/send_event.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import (
+    SynapseError, MatrixCodeMessageException, CodeMessageException,
+)
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.util.async import sleep
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.metrics import Measure
+from synapse.types import Requester, UserID
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def send_event_to_master(client, host, port, requester, event, context,
+                         ratelimit, extra_users):
+    """Send event to be handled on the master
+
+    Args:
+        client (SimpleHttpClient)
+        host (str): host of master
+        port (int): port on master listening for HTTP replication
+        requester (Requester)
+        event (FrozenEvent)
+        context (EventContext)
+        ratelimit (bool)
+        extra_users (list(UserID)): Any extra users to notify about event
+    """
+    uri = "http://%s:%s/_synapse/replication/send_event/%s" % (
+        host, port, event.event_id,
+    )
+
+    payload = {
+        "event": event.get_pdu_json(),
+        "internal_metadata": event.internal_metadata.get_dict(),
+        "rejected_reason": event.rejected_reason,
+        "context": context.serialize(event),
+        "requester": requester.serialize(),
+        "ratelimit": ratelimit,
+        "extra_users": [u.to_string() for u in extra_users],
+    }
+
+    try:
+        # We keep retrying the same request for timeouts. This is so that we
+        # have a good idea that the request has either succeeded or failed on
+        # the master, and so whether we should clean up or not.
+        while True:
+            try:
+                result = yield client.put_json(uri, payload)
+                break
+            except CodeMessageException as e:
+                if e.code != 504:
+                    raise
+
+            logger.warn("send_event request timed out")
+
+            # If we timed out we probably don't need to worry about backing
+            # off too much, but lets just wait a little anyway.
+            yield sleep(1)
+    except MatrixCodeMessageException as e:
+        # We convert to SynapseError as we know that it was a SynapseError
+        # on the master process that we should send to the client. (And
+        # importantly, not stack traces everywhere)
+        raise SynapseError(e.code, e.msg, e.errcode)
+    defer.returnValue(result)
+
+
+class ReplicationSendEventRestServlet(RestServlet):
+    """Handles events newly created on workers, including persisting and
+    notifying.
+
+    The API looks like:
+
+        POST /_synapse/replication/send_event/:event_id
+
+        {
+            "event": { .. serialized event .. },
+            "internal_metadata": { .. serialized internal_metadata .. },
+            "rejected_reason": ..,   // The event.rejected_reason field
+            "context": { .. serialized event context .. },
+            "requester": { .. serialized requester .. },
+            "ratelimit": true,
+            "extra_users": [],
+        }
+    """
+    PATTERNS = [re.compile("^/_synapse/replication/send_event/(?P<event_id>[^/]+)$")]
+
+    def __init__(self, hs):
+        super(ReplicationSendEventRestServlet, self).__init__()
+
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+        # The responses are tiny, so we may as well cache them for a while
+        self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
+
+    def on_PUT(self, request, event_id):
+        return self.response_cache.wrap(
+            event_id,
+            self._handle_request,
+            request
+        )
+
+    @defer.inlineCallbacks
+    def _handle_request(self, request):
+        with Measure(self.clock, "repl_send_event_parse"):
+            content = parse_json_object_from_request(request)
+
+            event_dict = content["event"]
+            internal_metadata = content["internal_metadata"]
+            rejected_reason = content["rejected_reason"]
+            event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
+
+            requester = Requester.deserialize(self.store, content["requester"])
+            context = yield EventContext.deserialize(self.store, content["context"])
+
+            ratelimit = content["ratelimit"]
+            extra_users = [UserID.from_string(u) for u in content["extra_users"]]
+
+        if requester.user:
+            request.authenticated_entity = requester.user.to_string()
+
+        logger.info(
+            "Got event to send with ID: %s into room: %s",
+            event.event_id, event.room_id,
+        )
+
+        yield self.event_creation_handler.persist_and_notify_client_event(
+            requester, event, context,
+            ratelimit=ratelimit,
+            extra_users=extra_users,
+        )
+
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    ReplicationSendEventRestServlet(hs).register(http_server)
diff --git a/synapse/replication/presence_resource.py b/synapse/replication/presence_resource.py
deleted file mode 100644
index fc18130ab4..0000000000
--- a/synapse/replication/presence_resource.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-
-class PresenceResource(Resource):
-    """
-    HTTP endpoint for marking users as syncing.
-
-    POST /_synapse/replication/presence HTTP/1.1
-    Content-Type: application/json
-
-    {
-        "process_id": "<process_id>",
-        "syncing_users": ["<user_id>"]
-    }
-    """
-
-    def __init__(self, hs):
-        Resource.__init__(self)  # Resource is old-style, so no super()
-
-        self.version_string = hs.version_string
-        self.clock = hs.get_clock()
-        self.presence_handler = hs.get_presence_handler()
-
-    def render_POST(self, request):
-        self._async_render_POST(request)
-        return NOT_DONE_YET
-
-    @request_handler()
-    @defer.inlineCallbacks
-    def _async_render_POST(self, request):
-        content = parse_json_object_from_request(request)
-
-        process_id = content["process_id"]
-        syncing_user_ids = content["syncing_users"]
-
-        yield self.presence_handler.update_external_syncs(
-            process_id, set(syncing_user_ids)
-        )
-
-        respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/pusher_resource.py b/synapse/replication/pusher_resource.py
deleted file mode 100644
index 9b01ab3c13..0000000000
--- a/synapse/replication/pusher_resource.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.server import respond_with_json_bytes, request_handler
-from synapse.http.servlet import parse_json_object_from_request
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-
-class PusherResource(Resource):
-    """
-    HTTP endpoint for deleting rejected pushers
-    """
-
-    def __init__(self, hs):
-        Resource.__init__(self)  # Resource is old-style, so no super()
-
-        self.version_string = hs.version_string
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self.clock = hs.get_clock()
-
-    def render_POST(self, request):
-        self._async_render_POST(request)
-        return NOT_DONE_YET
-
-    @request_handler()
-    @defer.inlineCallbacks
-    def _async_render_POST(self, request):
-        content = parse_json_object_from_request(request)
-
-        for remove in content["remove"]:
-            yield self.store.delete_pusher_by_app_id_pushkey_user_id(
-                remove["app_id"],
-                remove["push_key"],
-                remove["user_id"],
-            )
-
-        self.notifier.on_new_replication_data()
-
-        respond_with_json_bytes(request, 200, "{}")
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
deleted file mode 100644
index d8eb14592b..0000000000
--- a/synapse/replication/resource.py
+++ /dev/null
@@ -1,576 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.http.servlet import parse_integer, parse_string
-from synapse.http.server import request_handler, finish_request
-from synapse.replication.pusher_resource import PusherResource
-from synapse.replication.presence_resource import PresenceResource
-from synapse.replication.expire_cache import ExpireCacheResource
-from synapse.api.errors import SynapseError
-
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-
-import ujson as json
-
-import collections
-import logging
-
-logger = logging.getLogger(__name__)
-
-REPLICATION_PREFIX = "/_synapse/replication"
-
-STREAM_NAMES = (
-    ("events",),
-    ("presence",),
-    ("typing",),
-    ("receipts",),
-    ("user_account_data", "room_account_data", "tag_account_data",),
-    ("backfill",),
-    ("push_rules",),
-    ("pushers",),
-    ("caches",),
-    ("to_device",),
-    ("public_rooms",),
-    ("federation",),
-    ("device_lists",),
-)
-
-
-class ReplicationResource(Resource):
-    """
-    HTTP endpoint for extracting data from synapse.
-
-    The streams of data returned by the endpoint are controlled by the
-    parameters given to the API. To return a given stream pass a query
-    parameter with a position in the stream to return data from or the
-    special value "-1" to return data from the start of the stream.
-
-    If there is no data for any of the supplied streams after the given
-    position then the request will block until there is data for one
-    of the streams. This allows clients to long-poll this API.
-
-    The possible streams are:
-
-    * "streams": A special stream returing the positions of other streams.
-    * "events": The new events seen on the server.
-    * "presence": Presence updates.
-    * "typing": Typing updates.
-    * "receipts": Receipt updates.
-    * "user_account_data": Top-level per user account data.
-    * "room_account_data: Per room per user account data.
-    * "tag_account_data": Per room per user tags.
-    * "backfill": Old events that have been backfilled from other servers.
-    * "push_rules": Per user changes to push rules.
-    * "pushers": Per user changes to their pushers.
-    * "caches": Cache invalidations.
-
-    The API takes two additional query parameters:
-
-    * "timeout": How long to wait before returning an empty response.
-    * "limit": The maximum number of rows to return for the selected streams.
-
-    The response is a JSON object with keys for each stream with updates. Under
-    each key is a JSON object with:
-
-    * "position": The current position of the stream.
-    * "field_names": The names of the fields in each row.
-    * "rows": The updates as an array of arrays.
-
-    There are a number of ways this API could be used:
-
-    1) To replicate the contents of the backing database to another database.
-    2) To be notified when the contents of a shared backing database changes.
-    3) To "tail" the activity happening on a server for debugging.
-
-    In the first case the client would track all of the streams and store it's
-    own copy of the data.
-
-    In the second case the client might theoretically just be able to follow
-    the "streams" stream to track where the other streams are. However in
-    practise it will probably need to get the contents of the streams in
-    order to expire the any in-memory caches. Whether it gets the contents
-    of the streams from this replication API or directly from the backing
-    store is a matter of taste.
-
-    In the third case the client would use the "streams" stream to find what
-    streams are available and their current positions. Then it can start
-    long-polling this replication API for new data on those streams.
-    """
-
-    def __init__(self, hs):
-        Resource.__init__(self)  # Resource is old-style, so no super()
-
-        self.version_string = hs.version_string
-        self.store = hs.get_datastore()
-        self.sources = hs.get_event_sources()
-        self.presence_handler = hs.get_presence_handler()
-        self.typing_handler = hs.get_typing_handler()
-        self.federation_sender = hs.get_federation_sender()
-        self.notifier = hs.notifier
-        self.clock = hs.get_clock()
-        self.config = hs.get_config()
-
-        self.putChild("remove_pushers", PusherResource(hs))
-        self.putChild("syncing_users", PresenceResource(hs))
-        self.putChild("expire_cache", ExpireCacheResource(hs))
-
-    def render_GET(self, request):
-        self._async_render_GET(request)
-        return NOT_DONE_YET
-
-    @defer.inlineCallbacks
-    def current_replication_token(self):
-        stream_token = yield self.sources.get_current_token()
-        backfill_token = yield self.store.get_current_backfill_token()
-        push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
-        pushers_token = self.store.get_pushers_stream_token()
-        caches_token = self.store.get_cache_stream_token()
-        public_rooms_token = self.store.get_current_public_room_stream_id()
-        federation_token = self.federation_sender.get_current_token()
-        device_list_token = self.store.get_device_stream_token()
-
-        defer.returnValue(_ReplicationToken(
-            room_stream_token,
-            int(stream_token.presence_key),
-            int(stream_token.typing_key),
-            int(stream_token.receipt_key),
-            int(stream_token.account_data_key),
-            backfill_token,
-            push_rules_token,
-            pushers_token,
-            0,  # State stream is no longer a thing
-            caches_token,
-            int(stream_token.to_device_key),
-            int(public_rooms_token),
-            int(federation_token),
-            int(device_list_token),
-        ))
-
-    @request_handler()
-    @defer.inlineCallbacks
-    def _async_render_GET(self, request):
-        limit = parse_integer(request, "limit", 100)
-        timeout = parse_integer(request, "timeout", 10 * 1000)
-
-        request.setHeader(b"Content-Type", b"application/json")
-
-        request_streams = {
-            name: parse_integer(request, name)
-            for names in STREAM_NAMES for name in names
-        }
-        request_streams["streams"] = parse_string(request, "streams")
-
-        federation_ack = parse_integer(request, "federation_ack", None)
-
-        def replicate():
-            return self.replicate(
-                request_streams, limit,
-                federation_ack=federation_ack
-            )
-
-        writer = yield self.notifier.wait_for_replication(replicate, timeout)
-        result = writer.finish()
-
-        for stream_name, stream_content in result.items():
-            logger.info(
-                "Replicating %d rows of %s from %s -> %s",
-                len(stream_content["rows"]),
-                stream_name,
-                request_streams.get(stream_name),
-                stream_content["position"],
-            )
-
-        request.write(json.dumps(result, ensure_ascii=False))
-        finish_request(request)
-
-    @defer.inlineCallbacks
-    def replicate(self, request_streams, limit, federation_ack=None):
-        writer = _Writer()
-        current_token = yield self.current_replication_token()
-        logger.debug("Replicating up to %r", current_token)
-
-        if limit == 0:
-            raise SynapseError(400, "Limit cannot be 0")
-
-        yield self.account_data(writer, current_token, limit, request_streams)
-        yield self.events(writer, current_token, limit, request_streams)
-        # TODO: implement limit
-        yield self.presence(writer, current_token, request_streams)
-        yield self.typing(writer, current_token, request_streams)
-        yield self.receipts(writer, current_token, limit, request_streams)
-        yield self.push_rules(writer, current_token, limit, request_streams)
-        yield self.pushers(writer, current_token, limit, request_streams)
-        yield self.caches(writer, current_token, limit, request_streams)
-        yield self.to_device(writer, current_token, limit, request_streams)
-        yield self.public_rooms(writer, current_token, limit, request_streams)
-        yield self.device_lists(writer, current_token, limit, request_streams)
-        self.federation(writer, current_token, limit, request_streams, federation_ack)
-        self.streams(writer, current_token, request_streams)
-
-        logger.debug("Replicated %d rows", writer.total)
-        defer.returnValue(writer)
-
-    def streams(self, writer, current_token, request_streams):
-        request_token = request_streams.get("streams")
-
-        streams = []
-
-        if request_token is not None:
-            if request_token == "-1":
-                for names, position in zip(STREAM_NAMES, current_token):
-                    streams.extend((name, position) for name in names)
-            else:
-                items = zip(
-                    STREAM_NAMES,
-                    current_token,
-                    _ReplicationToken(request_token)
-                )
-                for names, current_id, last_id in items:
-                    if last_id < current_id:
-                        streams.extend((name, current_id) for name in names)
-
-            if streams:
-                writer.write_header_and_rows(
-                    "streams", streams, ("name", "position"),
-                    position=str(current_token)
-                )
-
-    @defer.inlineCallbacks
-    def events(self, writer, current_token, limit, request_streams):
-        request_events = request_streams.get("events")
-        request_backfill = request_streams.get("backfill")
-
-        if request_events is not None or request_backfill is not None:
-            if request_events is None:
-                request_events = current_token.events
-            if request_backfill is None:
-                request_backfill = current_token.backfill
-
-            no_new_tokens = (
-                request_events == current_token.events
-                and request_backfill == current_token.backfill
-            )
-            if no_new_tokens:
-                return
-
-            res = yield self.store.get_all_new_events(
-                request_backfill, request_events,
-                current_token.backfill, current_token.events,
-                limit
-            )
-
-            upto_events_token = _position_from_rows(
-                res.new_forward_events, current_token.events
-            )
-
-            upto_backfill_token = _position_from_rows(
-                res.new_backfill_events, current_token.backfill
-            )
-
-            if request_events != upto_events_token:
-                writer.write_header_and_rows("events", res.new_forward_events, (
-                    "position", "internal", "json", "state_group"
-                ), position=upto_events_token)
-
-            if request_backfill != upto_backfill_token:
-                writer.write_header_and_rows("backfill", res.new_backfill_events, (
-                    "position", "internal", "json", "state_group",
-                ), position=upto_backfill_token)
-
-            writer.write_header_and_rows(
-                "forward_ex_outliers", res.forward_ex_outliers,
-                ("position", "event_id", "state_group"),
-            )
-            writer.write_header_and_rows(
-                "backward_ex_outliers", res.backward_ex_outliers,
-                ("position", "event_id", "state_group"),
-            )
-
-    @defer.inlineCallbacks
-    def presence(self, writer, current_token, request_streams):
-        current_position = current_token.presence
-
-        request_presence = request_streams.get("presence")
-
-        if request_presence is not None and request_presence != current_position:
-            presence_rows = yield self.presence_handler.get_all_presence_updates(
-                request_presence, current_position
-            )
-            upto_token = _position_from_rows(presence_rows, current_position)
-            writer.write_header_and_rows("presence", presence_rows, (
-                "position", "user_id", "state", "last_active_ts",
-                "last_federation_update_ts", "last_user_sync_ts",
-                "status_msg", "currently_active",
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def typing(self, writer, current_token, request_streams):
-        current_position = current_token.typing
-
-        request_typing = request_streams.get("typing")
-
-        if request_typing is not None and request_typing != current_position:
-            # If they have a higher token than current max, we can assume that
-            # they had been talking to a previous instance of the master. Since
-            # we reset the token on restart, the best (but hacky) thing we can
-            # do is to simply resend down all the typing notifications.
-            if request_typing > current_position:
-                request_typing = 0
-
-            typing_rows = yield self.typing_handler.get_all_typing_updates(
-                request_typing, current_position
-            )
-            upto_token = _position_from_rows(typing_rows, current_position)
-            writer.write_header_and_rows("typing", typing_rows, (
-                "position", "room_id", "typing"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def receipts(self, writer, current_token, limit, request_streams):
-        current_position = current_token.receipts
-
-        request_receipts = request_streams.get("receipts")
-
-        if request_receipts is not None and request_receipts != current_position:
-            receipts_rows = yield self.store.get_all_updated_receipts(
-                request_receipts, current_position, limit
-            )
-            upto_token = _position_from_rows(receipts_rows, current_position)
-            writer.write_header_and_rows("receipts", receipts_rows, (
-                "position", "room_id", "receipt_type", "user_id", "event_id", "data"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def account_data(self, writer, current_token, limit, request_streams):
-        current_position = current_token.account_data
-
-        user_account_data = request_streams.get("user_account_data")
-        room_account_data = request_streams.get("room_account_data")
-        tag_account_data = request_streams.get("tag_account_data")
-
-        if user_account_data is not None or room_account_data is not None:
-            if user_account_data is None:
-                user_account_data = current_position
-            if room_account_data is None:
-                room_account_data = current_position
-
-            no_new_tokens = (
-                user_account_data == current_position
-                and room_account_data == current_position
-            )
-            if no_new_tokens:
-                return
-
-            user_rows, room_rows = yield self.store.get_all_updated_account_data(
-                user_account_data, room_account_data, current_position, limit
-            )
-
-            upto_users_token = _position_from_rows(user_rows, current_position)
-            upto_rooms_token = _position_from_rows(room_rows, current_position)
-
-            writer.write_header_and_rows("user_account_data", user_rows, (
-                "position", "user_id", "type", "content"
-            ), position=upto_users_token)
-            writer.write_header_and_rows("room_account_data", room_rows, (
-                "position", "user_id", "room_id", "type", "content"
-            ), position=upto_rooms_token)
-
-        if tag_account_data is not None:
-            tag_rows = yield self.store.get_all_updated_tags(
-                tag_account_data, current_position, limit
-            )
-            upto_tag_token = _position_from_rows(tag_rows, current_position)
-            writer.write_header_and_rows("tag_account_data", tag_rows, (
-                "position", "user_id", "room_id", "tags"
-            ), position=upto_tag_token)
-
-    @defer.inlineCallbacks
-    def push_rules(self, writer, current_token, limit, request_streams):
-        current_position = current_token.push_rules
-
-        push_rules = request_streams.get("push_rules")
-
-        if push_rules is not None and push_rules != current_position:
-            rows = yield self.store.get_all_push_rule_updates(
-                push_rules, current_position, limit
-            )
-            upto_token = _position_from_rows(rows, current_position)
-            writer.write_header_and_rows("push_rules", rows, (
-                "position", "event_stream_ordering", "user_id", "rule_id", "op",
-                "priority_class", "priority", "conditions", "actions"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def pushers(self, writer, current_token, limit, request_streams):
-        current_position = current_token.pushers
-
-        pushers = request_streams.get("pushers")
-
-        if pushers is not None and pushers != current_position:
-            updated, deleted = yield self.store.get_all_updated_pushers(
-                pushers, current_position, limit
-            )
-            upto_token = _position_from_rows(updated, current_position)
-            writer.write_header_and_rows("pushers", updated, (
-                "position", "user_id", "access_token", "profile_tag", "kind",
-                "app_id", "app_display_name", "device_display_name", "pushkey",
-                "ts", "lang", "data"
-            ), position=upto_token)
-            writer.write_header_and_rows("deleted_pushers", deleted, (
-                "position", "user_id", "app_id", "pushkey"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def caches(self, writer, current_token, limit, request_streams):
-        current_position = current_token.caches
-
-        caches = request_streams.get("caches")
-
-        if caches is not None and caches != current_position:
-            updated_caches = yield self.store.get_all_updated_caches(
-                caches, current_position, limit
-            )
-            upto_token = _position_from_rows(updated_caches, current_position)
-            writer.write_header_and_rows("caches", updated_caches, (
-                "position", "cache_func", "keys", "invalidation_ts"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def to_device(self, writer, current_token, limit, request_streams):
-        current_position = current_token.to_device
-
-        to_device = request_streams.get("to_device")
-
-        if to_device is not None and to_device != current_position:
-            to_device_rows = yield self.store.get_all_new_device_messages(
-                to_device, current_position, limit
-            )
-            upto_token = _position_from_rows(to_device_rows, current_position)
-            writer.write_header_and_rows("to_device", to_device_rows, (
-                "position", "user_id", "device_id", "message_json"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def public_rooms(self, writer, current_token, limit, request_streams):
-        current_position = current_token.public_rooms
-
-        public_rooms = request_streams.get("public_rooms")
-
-        if public_rooms is not None and public_rooms != current_position:
-            public_rooms_rows = yield self.store.get_all_new_public_rooms(
-                public_rooms, current_position, limit
-            )
-            upto_token = _position_from_rows(public_rooms_rows, current_position)
-            writer.write_header_and_rows("public_rooms", public_rooms_rows, (
-                "position", "room_id", "visibility", "appservice_id", "network_id",
-            ), position=upto_token)
-
-    def federation(self, writer, current_token, limit, request_streams, federation_ack):
-        if self.config.send_federation:
-            return
-
-        current_position = current_token.federation
-
-        federation = request_streams.get("federation")
-
-        if federation is not None and federation != current_position:
-            federation_rows = self.federation_sender.get_replication_rows(
-                federation, limit, federation_ack=federation_ack,
-            )
-            upto_token = _position_from_rows(federation_rows, current_position)
-            writer.write_header_and_rows("federation", federation_rows, (
-                "position", "type", "content",
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def device_lists(self, writer, current_token, limit, request_streams):
-        current_position = current_token.device_lists
-
-        device_lists = request_streams.get("device_lists")
-
-        if device_lists is not None and device_lists != current_position:
-            changes = yield self.store.get_all_device_list_changes_for_remotes(
-                device_lists,
-            )
-            writer.write_header_and_rows("device_lists", changes, (
-                "position", "user_id", "destination",
-            ), position=current_position)
-
-
-class _Writer(object):
-    """Writes the streams as a JSON object as the response to the request"""
-    def __init__(self):
-        self.streams = {}
-        self.total = 0
-
-    def write_header_and_rows(self, name, rows, fields, position=None):
-        if position is None:
-            if rows:
-                position = rows[-1][0]
-            else:
-                return
-
-        self.streams[name] = {
-            "position": position if type(position) is int else str(position),
-            "field_names": fields,
-            "rows": rows,
-        }
-
-        self.total += len(rows)
-
-    def __nonzero__(self):
-        return bool(self.total)
-
-    def finish(self):
-        return self.streams
-
-
-class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
-    "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
-    "federation", "device_lists",
-))):
-    __slots__ = []
-
-    def __new__(cls, *args):
-        if len(args) == 1:
-            streams = [int(value) for value in args[0].split("_")]
-            if len(streams) < len(cls._fields):
-                streams.extend([0] * (len(cls._fields) - len(streams)))
-            return cls(*streams)
-        else:
-            return super(_ReplicationToken, cls).__new__(cls, *args)
-
-    def __str__(self):
-        return "_".join(str(value) for value in self)
-
-
-def _position_from_rows(rows, current_position):
-    """Calculates a position to return for a stream. Ideally we want to return the
-    position of the last row, as that will be the most correct. However, if there
-    are no rows we fall back to using the current position to stop us from
-    repeatedly hitting the storage layer unncessarily thinking there are updates.
-    (Not all advances of the token correspond to an actual update)
-
-    We can't just always return the current position, as we often limit the
-    number of rows we replicate, and so the stream may lag. The assumption is
-    that if the storage layer returns no new rows then we are not lagging and
-    we are at the `current_position`.
-    """
-    if rows:
-        return rows[-1][0]
-    return current_position
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 18076e0f3b..61f5590c53 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
 
 from ._slaved_id_tracker import SlavedIdTracker
 
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 class BaseSlavedStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(hs)
+        super(BaseSlavedStore, self).__init__(db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = SlavedIdTracker(
                 db_conn, "cache_invalidation_stream", "stream_id",
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs
 
     def stream_positions(self):
         pos = {}
@@ -43,33 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
-                    logger.info("Got unexpected cache_func: %r", cache_func)
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)
+                    # We probably haven't pulled in the cache in this worker,
+                    # which is fine.
+                    pass
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)
 
-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 24b5c79d4a..9d1d173b2f 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -27,4 +27,9 @@ class SlavedIdTracker(object):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self):
+        """
+
+        Returns:
+            int
+        """
         return self._current
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..d9ba6d69b1 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,50 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.account_data import AccountDataStore
-from synapse.storage.tags import TagsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.tags import TagsWorkerStore
 
 
-class SlavedAccountDataStore(BaseSlavedStore):
+class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id",
         )
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache",
-            self._account_data_id_gen.get_current_token(),
-        )
-
-    get_account_data_for_user = (
-        AccountDataStore.__dict__["get_account_data_for_user"]
-    )
-
-    get_global_account_data_by_type_for_users = (
-        AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
-    )
 
-    get_global_account_data_by_type_for_user = (
-        AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
-    )
-
-    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
-    get_tags_for_room = (
-        DataStore.get_tags_for_room.__func__
-    )
-    get_account_data_for_room = (
-        DataStore.get_account_data_for_room.__func__
-    )
-
-    get_updated_tags = DataStore.get_updated_tags.__func__
-    get_updated_account_data_for_user = (
-        DataStore.get_updated_account_data_for_user.__func__
-    )
+        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
@@ -69,38 +40,29 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type,),
                 )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a374f2f1a2..8cae3076f4 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,28 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.config.appservice import load_appservices
+from synapse.storage.appservice import (
+    ApplicationServiceWorkerStore, ApplicationServiceTransactionWorkerStore,
+)
 
 
-class SlavedApplicationServiceStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
-        self.services_cache = load_appservices(
-            hs.config.server_name,
-            hs.config.app_service_config_files
-        )
-
-    get_app_service_by_token = DataStore.get_app_service_by_token.__func__
-    get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
-    get_app_services = DataStore.get_app_services.__func__
-    get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
-    create_appservice_txn = DataStore.create_appservice_txn.__func__
-    get_appservices_by_state = DataStore.get_appservices_by_state.__func__
-    get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
-    _get_last_txn = DataStore._get_last_txn.__func__
-    complete_appservice_txn = DataStore.complete_appservice_txn.__func__
-    get_appservice_state = DataStore.get_appservice_state.__func__
-    set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
-    set_appservice_state = DataStore.set_appservice_state.__func__
+class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
+                                    ApplicationServiceWorkerStore):
+    pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
new file mode 100644
index 0000000000..352c9a2aa8
--- /dev/null
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
+from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import Cache
+
+
+class SlavedClientIpStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+
+        self.client_ip_last_seen = Cache(
+            name="client_ip_last_seen",
+            keylen=4,
+            max_entries=50000 * CACHE_SIZE_FACTOR,
+        )
+
+    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+        now = int(self._clock.time_msec())
+        key = (user_id, access_token, ip)
+
+        try:
+            last_seen = self.client_ip_last_seen.get(key)
+        except KeyError:
+            last_seen = None
+
+        # Rate-limited inserts
+        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+            return
+
+        self.hs.get_tcp_replication().send_user_ip(
+            user_id, access_token, ip, user_agent, device_id, now
+        )
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index cc860f9f9b..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -17,6 +17,7 @@ from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 from synapse.storage import DataStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.caches.expiringcache import ExpiringCache
 
 
 class SlavedDeviceInboxStore(BaseSlavedStore):
@@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
             self._device_inbox_id_gen.get_current_token()
         )
 
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
     get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
     get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
     get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__
@@ -45,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..7687867aee 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,6 +16,7 @@
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 from synapse.storage import DataStore
+from synapse.storage.end_to_end_keys import EndToEndKeyStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
@@ -45,28 +46,25 @@ class SlavedDeviceStore(BaseSlavedStore):
     _mark_as_sent_devices_by_remote_txn = (
         DataStore._mark_as_sent_devices_by_remote_txn.__func__
     )
+    count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
 
     def stream_positions(self):
         result = super(SlavedDeviceStore, self).stream_positions()
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )
 
-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 7301d885f2..6deecd3963 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -14,10 +14,8 @@
 # limitations under the License.
 
 from ._base import BaseSlavedStore
-from synapse.storage.directory import DirectoryStore
+from synapse.storage.directory import DirectoryWorkerStore
 
 
-class DirectoryStore(BaseSlavedStore):
-    get_aliases_for_room = DirectoryStore.__dict__[
-        "get_aliases_for_room"
-    ]
+class DirectoryStore(DirectoryWorkerStore, BaseSlavedStore):
+    pass
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d72ff6055c..b1f64ef0d8 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,22 +13,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-
-from synapse.api.constants import EventTypes
-from synapse.events import FrozenEvent
-from synapse.storage import DataStore
-from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.event_push_actions import EventPushActionsStore
-from synapse.storage.state import StateStore
-from synapse.storage.stream import StreamStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-import ujson as json
 import logging
 
+from synapse.api.constants import EventTypes
+from synapse.storage.event_federation import EventFederationWorkerStore
+from synapse.storage.event_push_actions import EventPushActionsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
+from synapse.storage.state import StateGroupWorkerStore
+from synapse.storage.stream import StreamWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -41,152 +38,33 @@ logger = logging.getLogger(__name__)
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedEventStore(BaseSlavedStore):
+class SlavedEventStore(EventFederationWorkerStore,
+                       RoomMemberWorkerStore,
+                       EventPushActionsWorkerStore,
+                       StreamWorkerStore,
+                       EventsWorkerStore,
+                       StateGroupWorkerStore,
+                       SignatureWorkerStore,
+                       BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
-        super(SlavedEventStore, self).__init__(db_conn, hs)
         self._stream_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering",
         )
         self._backfill_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering", step=-1
         )
-        events_max = self._stream_id_gen.get_current_token()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
-            db_conn, "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache", min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max,
-        )
 
-        self.stream_ordering_month_ago = 0
-        self._stream_order_on_start = self.get_room_max_stream_ordering()
+        super(SlavedEventStore, self).__init__(db_conn, hs)
 
     # Cached functions can't be accessed through a class instance so we need
     # to reach inside the __dict__ to extract them.
-    get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
-    get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
-    get_users_who_share_room_with_user = (
-        RoomMemberStore.__dict__["get_users_who_share_room_with_user"]
-    )
-    get_latest_event_ids_in_room = EventFederationStore.__dict__[
-        "get_latest_event_ids_in_room"
-    ]
-    get_invited_rooms_for_user = RoomMemberStore.__dict__[
-        "get_invited_rooms_for_user"
-    ]
-    get_unread_event_push_actions_by_room_for_user = (
-        EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
-    )
-    _get_state_group_for_events = (
-        StateStore.__dict__["_get_state_group_for_events"]
-    )
-    _get_state_group_for_event = (
-        StateStore.__dict__["_get_state_group_for_event"]
-    )
-    _get_state_groups_from_groups = (
-        StateStore.__dict__["_get_state_groups_from_groups"]
-    )
-    _get_state_groups_from_groups_txn = (
-        DataStore._get_state_groups_from_groups_txn.__func__
-    )
-    _get_state_group_from_group = (
-        StateStore.__dict__["_get_state_group_from_group"]
-    )
-    get_recent_event_ids_for_room = (
-        StreamStore.__dict__["get_recent_event_ids_for_room"]
-    )
-
-    get_unread_push_actions_for_user_in_range_for_http = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
-    )
-    get_unread_push_actions_for_user_in_range_for_email = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
-    )
-    get_push_action_users_in_range = (
-        DataStore.get_push_action_users_in_range.__func__
-    )
-    get_event = DataStore.get_event.__func__
-    get_events = DataStore.get_events.__func__
-    get_rooms_for_user_where_membership_is = (
-        DataStore.get_rooms_for_user_where_membership_is.__func__
-    )
-    get_membership_changes_for_user = (
-        DataStore.get_membership_changes_for_user.__func__
-    )
-    get_room_events_max_id = DataStore.get_room_events_max_id.__func__
-    get_room_events_stream_for_room = (
-        DataStore.get_room_events_stream_for_room.__func__
-    )
-    get_events_around = DataStore.get_events_around.__func__
-    get_state_for_event = DataStore.get_state_for_event.__func__
-    get_state_for_events = DataStore.get_state_for_events.__func__
-    get_state_groups = DataStore.get_state_groups.__func__
-    get_state_groups_ids = DataStore.get_state_groups_ids.__func__
-    get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
-    get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
-    get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
-    get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
-    _get_joined_users_from_context = (
-        RoomMemberStore.__dict__["_get_joined_users_from_context"]
-    )
-
-    get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
-    get_room_events_stream_for_rooms = (
-        DataStore.get_room_events_stream_for_rooms.__func__
-    )
-    is_host_joined = DataStore.is_host_joined.__func__
-    _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
-    get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
-
-    _set_before_and_after = staticmethod(DataStore._set_before_and_after)
-
-    _get_events = DataStore._get_events.__func__
-    _get_events_from_cache = DataStore._get_events_from_cache.__func__
 
-    _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
-    _enqueue_events = DataStore._enqueue_events.__func__
-    _do_fetch = DataStore._do_fetch.__func__
-    _fetch_event_rows = DataStore._fetch_event_rows.__func__
-    _get_event_from_row = DataStore._get_event_from_row.__func__
-    _get_rooms_for_user_where_membership_is_txn = (
-        DataStore._get_rooms_for_user_where_membership_is_txn.__func__
-    )
-    _get_members_rows_txn = DataStore._get_members_rows_txn.__func__
-    _get_state_for_groups = DataStore._get_state_for_groups.__func__
-    _get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
-    _get_events_around_txn = DataStore._get_events_around_txn.__func__
-    _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
+    def get_room_max_stream_ordering(self):
+        return self._stream_id_gen.get_current_token()
 
-    get_backfill_events = DataStore.get_backfill_events.__func__
-    _get_backfill_events = DataStore._get_backfill_events.__func__
-    get_missing_events = DataStore.get_missing_events.__func__
-    _get_missing_events = DataStore._get_missing_events.__func__
-
-    get_auth_chain = DataStore.get_auth_chain.__func__
-    get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
-    _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
-
-    get_room_max_stream_ordering = DataStore.get_room_max_stream_ordering.__func__
-
-    get_forward_extremeties_for_room = (
-        DataStore.get_forward_extremeties_for_room.__func__
-    )
-    _get_forward_extremeties_for_room = (
-        EventFederationStore.__dict__["_get_forward_extremeties_for_room"]
-    )
-
-    get_all_new_events_stream = DataStore.get_all_new_events_stream.__func__
-
-    get_federation_out_pos = DataStore.get_federation_out_pos.__func__
-    update_federation_out_pos = DataStore.update_federation_out_pos.__func__
+    def get_room_min_stream_ordering(self):
+        return self._backfill_id_gen.get_current_token()
 
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
@@ -194,84 +72,47 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
                 )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
                 )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        internal = json.loads(row[1])
-        event_json = json.loads(row[2])
-        event = FrozenEvent(event_json, internal_metadata_dict=internal)
-        self.invalidate_caches_for_event(
-            event, backfilled,
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
         )
 
-    def invalidate_caches_for_event(self, event, backfilled):
-        self._invalidate_get_event_cache(event.event_id)
+    def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
+                                    etype, state_key, redacts, backfilled):
+        self._invalidate_get_event_cache(event_id)
 
-        self.get_latest_event_ids_in_room.invalidate((event.room_id,))
+        self.get_latest_event_ids_in_room.invalidate((room_id,))
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-            (event.room_id,)
+            (room_id,)
         )
 
         if not backfilled:
             self._events_stream_cache.entity_has_changed(
-                event.room_id, event.internal_metadata.stream_ordering
+                room_id, stream_ordering
             )
 
-        # self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
-        #     (event.room_id,)
-        # )
+        if redacts:
+            self._invalidate_get_event_cache(redacts)
 
-        if event.type == EventTypes.Redaction:
-            self._invalidate_get_event_cache(event.redacts)
-
-        if event.type == EventTypes.Member:
+        if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(
-                event.state_key, event.internal_metadata.stream_ordering
+                state_key, stream_ordering
             )
-            self.get_invited_rooms_for_user.invalidate((event.state_key,))
-
-        if not event.is_state():
-            return
-
-        if backfilled:
-            return
-
-        if (not event.internal_metadata.is_invite_from_remote()
-                and event.internal_metadata.is_outlier()):
-            return
+            self.get_invited_rooms_for_user.invalidate((state_key,))
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
new file mode 100644
index 0000000000..0bc4bce5b0
--- /dev/null
+++ b/synapse/replication/slave/storage/groups.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedGroupServerStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+
+        self.hs = hs
+
+        self._group_updates_id_gen = SlavedIdTracker(
+            db_conn, "local_group_updates", "stream_id",
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+        )
+
+    get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
+    get_group_stream_token = DataStore.get_group_stream_token.__func__
+    get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
+
+    def stream_positions(self):
+        result = super(SlavedGroupServerStore, self).stream_positions()
+        result["groups"] = self._group_updates_id_gen.get_current_token()
+        return result
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "groups":
+            self._group_updates_id_gen.advance(token)
+            for row in rows:
+                self._group_updates_stream_cache.entity_has_changed(
+                    row.user_id, token
+                )
+
+        return super(SlavedGroupServerStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 40f6c9a386..cfb9280181 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -39,6 +39,16 @@ class SlavedPresenceStore(BaseSlavedStore):
     _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
     get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
 
+    # XXX: This is a bit broken because we don't persist the accepted list in a
+    # way that can be replicated. This means that we don't have a way to
+    # invalidate the cache correctly.
+    get_presence_list_accepted = PresenceStore.__dict__[
+        "get_presence_list_accepted"
+    ]
+    get_presence_list_observers_accepted = PresenceStore.__dict__[
+        "get_presence_list_observers_accepted"
+    ]
+
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
@@ -48,14 +58,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
new file mode 100644
index 0000000000..46c28d4171
--- /dev/null
+++ b/synapse/replication/slave/storage/profile.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.storage.profile import ProfileWorkerStore
+
+
+class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
+    pass
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..bb2c40b6e3 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,29 +16,15 @@
 
 from .events import SlavedEventStore
 from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.storage.push_rule import PushRuleStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.push_rule import PushRulesWorkerStore
 
 
-class SlavedPushRuleStore(SlavedEventStore):
+class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
     def __init__(self, db_conn, hs):
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id",
         )
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            self._push_rules_stream_id_gen.get_current_token(),
-        )
-
-    get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
-    get_push_rules_enabled_for_user = (
-        PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
-    )
-    have_push_rules_changed_for_user = (
-        DataStore.have_push_rules_changed_for_user.__func__
-    )
+        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
@@ -45,23 +32,23 @@ class SlavedPushRuleStore(SlavedEventStore):
             self._stream_id_gen.get_current_token(),
         )
 
+    def get_max_push_rules_stream_id(self):
+        return self._push_rules_stream_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedPushRuleStore, self).stream_positions()
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..a7cd5a7291 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,10 +17,10 @@
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-from synapse.storage import DataStore
+from synapse.storage.pusher import PusherWorkerStore
 
 
-class SlavedPusherStore(BaseSlavedStore):
+class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
         super(SlavedPusherStore, self).__init__(db_conn, hs)
@@ -28,25 +29,14 @@ class SlavedPusherStore(BaseSlavedStore):
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
-    get_all_pushers = DataStore.get_all_pushers.__func__
-    get_pushers_by = DataStore.get_pushers_by.__func__
-    get_pushers_by_app_id_and_pushkey = (
-        DataStore.get_pushers_by_app_id_and_pushkey.__func__
-    )
-    _decode_pushers_rows = DataStore._decode_pushers_rows.__func__
-
     def stream_positions(self):
         result = super(SlavedPusherStore, self).stream_positions()
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..1647072f65 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,9 +17,7 @@
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-from synapse.storage import DataStore
-from synapse.storage.receipts import ReceiptsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.storage.receipts import ReceiptsWorkerStore
 
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -29,56 +28,43 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedReceiptsStore(BaseSlavedStore):
+class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
-
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
-        )
-
-    get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
-    get_linearized_receipts_for_room = (
-        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
-    )
-    _get_linearized_receipts_for_rooms = (
-        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
-    )
-    get_last_receipt_event_id_for_user = (
-        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
-    )
-
-    get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
-    get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
 
-    get_linearized_receipts_for_rooms = (
-        DataStore.get_linearized_receipts_for_rooms.__func__
-    )
+    def get_max_receipt_stream_id(self):
+        return self._receipts_id_gen.get_current_token()
 
     def stream_positions(self):
         result = super(SlavedReceiptsStore, self).stream_positions()
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index e27c7332d2..7323bf0f1e 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -14,20 +14,8 @@
 # limitations under the License.
 
 from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.registration import RegistrationStore
+from synapse.storage.registration import RegistrationWorkerStore
 
 
-class SlavedRegistrationStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedRegistrationStore, self).__init__(db_conn, hs)
-
-    # TODO: use the cached version and invalidate deleted tokens
-    get_user_by_access_token = RegistrationStore.__dict__[
-        "get_user_by_access_token"
-    ]
-
-    _query_for_auth = DataStore._query_for_auth.__func__
-    get_user_by_id = RegistrationStore.__dict__[
-        "get_user_by_id"
-    ]
+class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
+    pass
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..5ae1670157 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,41 +14,29 @@
 # limitations under the License.
 
 from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.room import RoomStore
+from synapse.storage.room import RoomWorkerStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
-class RoomStore(BaseSlavedStore):
+class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def __init__(self, db_conn, hs):
         super(RoomStore, self).__init__(db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
 
-    get_public_room_ids = DataStore.get_public_room_ids.__func__
-    get_current_public_room_stream_id = (
-        DataStore.get_current_public_room_stream_id.__func__
-    )
-    get_public_room_ids_at_stream_id = (
-        RoomStore.__dict__["get_public_room_ids_at_stream_id"]
-    )
-    get_public_room_ids_at_stream_id_txn = (
-        DataStore.get_public_room_ids_at_stream_id_txn.__func__
-    )
-    get_published_at_stream_id_txn = (
-        DataStore.get_published_at_stream_id_txn.__func__
-    )
-    get_public_room_changes = DataStore.get_public_room_changes.__func__
+    def get_current_public_room_stream_id(self):
+        return self._public_room_id_gen.get_current_token()
 
     def stream_positions(self):
         result = super(RoomStore, self).stream_positions()
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication(result)
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
new file mode 100644
index 0000000000..81c2ea7ee9
--- /dev/null
+++ b/synapse/replication/tcp/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module implements the TCP replication protocol used by synapse to
+communicate between the master process and its workers (when they're enabled).
+
+Further details can be found in docs/tcp_replication.rst
+
+
+Structure of the module:
+ * client.py   - the client classes used for workers to connect to master
+ * command.py  - the definitions of all the valid commands
+ * protocol.py - contains bot the client and server protocol implementations,
+                 these should not be used directly
+ * resource.py - the server classes that accepts and handle client connections
+ * streams.py  - the definitons of all the valid streams
+
+"""
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
new file mode 100644
index 0000000000..6d2513c4e2
--- /dev/null
+++ b/synapse/replication/tcp/client.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A replication client for use by synapse workers.
+"""
+
+from twisted.internet import reactor, defer
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from .commands import (
+    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+    UserIpCommand,
+)
+from .protocol import ClientReplicationStreamProtocol
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationClientFactory(ReconnectingClientFactory):
+    """Factory for building connections to the master. Will reconnect if the
+    connection is lost.
+
+    Accepts a handler that will be called when new data is available or data
+    is required.
+    """
+    maxDelay = 5  # Try at least once every N seconds
+
+    def __init__(self, hs, client_name, handler):
+        self.client_name = client_name
+        self.handler = handler
+        self.server_name = hs.config.server_name
+        self._clock = hs.get_clock()  # As self.clock is defined in super class
+
+        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+
+    def startedConnecting(self, connector):
+        logger.info("Connecting to replication: %r", connector.getDestination())
+
+    def buildProtocol(self, addr):
+        logger.info("Connected to replication: %r", addr)
+        self.resetDelay()
+        return ClientReplicationStreamProtocol(
+            self.client_name, self.server_name, self._clock, self.handler
+        )
+
+    def clientConnectionLost(self, connector, reason):
+        logger.error("Lost replication conn: %r", reason)
+        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
+
+    def clientConnectionFailed(self, connector, reason):
+        logger.error("Failed to connect to replication: %r", reason)
+        ReconnectingClientFactory.clientConnectionFailed(
+            self, connector, reason
+        )
+
+
+class ReplicationClientHandler(object):
+    """A base handler that can be passed to the ReplicationClientFactory.
+
+    By default proxies incoming replication data to the SlaveStore.
+    """
+    def __init__(self, store):
+        self.store = store
+
+        # The current connection. None if we are currently (re)connecting
+        self.connection = None
+
+        # Any pending commands to be sent once a new connection has been
+        # established
+        self.pending_commands = []
+
+        # Map from string -> deferred, to wake up when receiveing a SYNC with
+        # the given string.
+        # Used for tests.
+        self.awaiting_syncs = {}
+
+    def start_replication(self, hs):
+        """Helper method to start a replication connection to the remote server
+        using TCP.
+        """
+        client_name = hs.config.worker_name
+        factory = ReplicationClientFactory(hs, client_name, self)
+        host = hs.config.worker_replication_host
+        port = hs.config.worker_replication_port
+        reactor.connectTCP(host, port, factory)
+
+    def on_rdata(self, stream_name, token, rows):
+        """Called when we get new replication data. By default this just pokes
+        the slave store.
+
+        Can be overriden in subclasses to handle more.
+        """
+        logger.info("Received rdata %s -> %s", stream_name, token)
+        self.store.process_replication_rows(stream_name, token, rows)
+
+    def on_position(self, stream_name, token):
+        """Called when we get new position data. By default this just pokes
+        the slave store.
+
+        Can be overriden in subclasses to handle more.
+        """
+        self.store.process_replication_rows(stream_name, token, [])
+
+    def on_sync(self, data):
+        """When we received a SYNC we wake up any deferreds that were waiting
+        for the sync with the given data.
+
+        Used by tests.
+        """
+        d = self.awaiting_syncs.pop(data, None)
+        if d:
+            d.callback(data)
+
+    def get_streams_to_replicate(self):
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns a dictionary of stream name to token.
+        """
+        args = self.store.stream_positions()
+        user_account_data = args.pop("user_account_data", None)
+        room_account_data = args.pop("room_account_data", None)
+        if user_account_data:
+            args["account_data"] = user_account_data
+        elif room_account_data:
+            args["account_data"] = room_account_data
+        return args
+
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users. (Overriden by the synchrotron's only)
+        """
+        return []
+
+    def send_command(self, cmd):
+        """Send a command to master (when we get establish a connection if we
+        don't have one already.)
+        """
+        if self.connection:
+            self.connection.send_command(cmd)
+        else:
+            logger.warn("Queuing command as not connected: %r", cmd.NAME)
+            self.pending_commands.append(cmd)
+
+    def send_federation_ack(self, token):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+
+    def send_remove_pusher(self, app_id, push_key, user_id):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func, keys):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def await_sync(self, data):
+        """Returns a deferred that is resolved when we receive a SYNC command
+        with given data.
+
+        Used by tests.
+        """
+        return self.awaiting_syncs.setdefault(data, defer.Deferred())
+
+    def update_connection(self, connection):
+        """Called when a connection has been established (or lost with None).
+        """
+        self.connection = connection
+        if connection:
+            for cmd in self.pending_commands:
+                connection.send_command(cmd)
+            self.pending_commands = []
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
new file mode 100644
index 0000000000..12aac3cc6b
--- /dev/null
+++ b/synapse/replication/tcp/commands.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Defines the various valid commands
+
+The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
+allowed to be sent by which side.
+"""
+
+import logging
+import simplejson
+
+
+logger = logging.getLogger(__name__)
+
+_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
+
+
+class Command(object):
+    """The base command class.
+
+    All subclasses must set the NAME variable which equates to the name of the
+    command on the wire.
+
+    A full command line on the wire is constructed from `NAME + " " + to_line()`
+
+    The default implementation creates a command of form `<NAME> <data>`
+    """
+    NAME = None
+
+    def __init__(self, data):
+        self.data = data
+
+    @classmethod
+    def from_line(cls, line):
+        """Deserialises a line from the wire into this command. `line` does not
+        include the command.
+        """
+        return cls(line)
+
+    def to_line(self):
+        """Serialises the comamnd for the wire. Does not include the command
+        prefix.
+        """
+        return self.data
+
+
+class ServerCommand(Command):
+    """Sent by the server on new connection and includes the server_name.
+
+    Format::
+
+        SERVER <server_name>
+    """
+    NAME = "SERVER"
+
+
+class RdataCommand(Command):
+    """Sent by server when a subscribed stream has an update.
+
+    Format::
+
+        RDATA <stream_name> <token> <row_json>
+
+    The `<token>` may either be a numeric stream id OR "batch". The latter case
+    is used to support sending multiple updates with the same stream ID. This
+    is done by sending an RDATA for each row, with all but the last RDATA having
+    a token of "batch" and the last having the final stream ID.
+
+    The client should batch all incoming RDATA with a token of "batch" (per
+    stream_name) until it sees an RDATA with a numeric stream ID.
+
+    `<token>` of "batch" maps to the instance variable `token` being None.
+
+    An example of a batched series of RDATA::
+
+        RDATA presence batch ["@foo:example.com", "online", ...]
+        RDATA presence batch ["@bar:example.com", "online", ...]
+        RDATA presence 59 ["@baz:example.com", "online", ...]
+    """
+    NAME = "RDATA"
+
+    def __init__(self, stream_name, token, row):
+        self.stream_name = stream_name
+        self.token = token
+        self.row = row
+
+    @classmethod
+    def from_line(cls, line):
+        stream_name, token, row_json = line.split(" ", 2)
+        return cls(
+            stream_name,
+            None if token == "batch" else int(token),
+            simplejson.loads(row_json)
+        )
+
+    def to_line(self):
+        return " ".join((
+            self.stream_name,
+            str(self.token) if self.token is not None else "batch",
+            _json_encoder.encode(self.row),
+        ))
+
+
+class PositionCommand(Command):
+    """Sent by the client to tell the client the stream postition without
+    needing to send an RDATA.
+    """
+    NAME = "POSITION"
+
+    def __init__(self, stream_name, token):
+        self.stream_name = stream_name
+        self.token = token
+
+    @classmethod
+    def from_line(cls, line):
+        stream_name, token = line.split(" ", 1)
+        return cls(stream_name, int(token))
+
+    def to_line(self):
+        return " ".join((self.stream_name, str(self.token),))
+
+
+class ErrorCommand(Command):
+    """Sent by either side if there was an ERROR. The data is a string describing
+    the error.
+    """
+    NAME = "ERROR"
+
+
+class PingCommand(Command):
+    """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+    """
+    NAME = "PING"
+
+
+class NameCommand(Command):
+    """Sent by client to inform the server of the client's identity. The data
+    is the name
+    """
+    NAME = "NAME"
+
+
+class ReplicateCommand(Command):
+    """Sent by the client to subscribe to the stream.
+
+    Format::
+
+        REPLICATE <stream_name> <token>
+
+    Where <token> may be either:
+        * a numeric stream_id to stream updates from
+        * "NOW" to stream all subsequent updates.
+
+    The <stream_name> can be "ALL" to subscribe to all known streams, in which
+    case the <token> must be set to "NOW", i.e.::
+
+        REPLICATE ALL NOW
+    """
+    NAME = "REPLICATE"
+
+    def __init__(self, stream_name, token):
+        self.stream_name = stream_name
+        self.token = token
+
+    @classmethod
+    def from_line(cls, line):
+        stream_name, token = line.split(" ", 1)
+        if token in ("NOW", "now"):
+            token = "NOW"
+        else:
+            token = int(token)
+        return cls(stream_name, token)
+
+    def to_line(self):
+        return " ".join((self.stream_name, str(self.token),))
+
+
+class UserSyncCommand(Command):
+    """Sent by the client to inform the server that a user has started or
+    stopped syncing. Used to calculate presence on the master.
+
+    Includes a timestamp of when the last user sync was.
+
+    Format::
+
+        USER_SYNC <user_id> <state> <last_sync_ms>
+
+    Where <state> is either "start" or "stop"
+    """
+    NAME = "USER_SYNC"
+
+    def __init__(self, user_id, is_syncing, last_sync_ms):
+        self.user_id = user_id
+        self.is_syncing = is_syncing
+        self.last_sync_ms = last_sync_ms
+
+    @classmethod
+    def from_line(cls, line):
+        user_id, state, last_sync_ms = line.split(" ", 2)
+
+        if state not in ("start", "end"):
+            raise Exception("Invalid USER_SYNC state %r" % (state,))
+
+        return cls(user_id, state == "start", int(last_sync_ms))
+
+    def to_line(self):
+        return " ".join((
+            self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
+        ))
+
+
+class FederationAckCommand(Command):
+    """Sent by the client when it has processed up to a given point in the
+    federation stream. This allows the master to drop in-memory caches of the
+    federation stream.
+
+    This must only be sent from one worker (i.e. the one sending federation)
+
+    Format::
+
+        FEDERATION_ACK <token>
+    """
+    NAME = "FEDERATION_ACK"
+
+    def __init__(self, token):
+        self.token = token
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(int(line))
+
+    def to_line(self):
+        return str(self.token)
+
+
+class SyncCommand(Command):
+    """Used for testing. The client protocol implementation allows waiting
+    on a SYNC command with a specified data.
+    """
+    NAME = "SYNC"
+
+
+class RemovePusherCommand(Command):
+    """Sent by the client to request the master remove the given pusher.
+
+    Format::
+
+        REMOVE_PUSHER <app_id> <push_key> <user_id>
+    """
+    NAME = "REMOVE_PUSHER"
+
+    def __init__(self, app_id, push_key, user_id):
+        self.user_id = user_id
+        self.app_id = app_id
+        self.push_key = push_key
+
+    @classmethod
+    def from_line(cls, line):
+        app_id, push_key, user_id = line.split(" ", 2)
+
+        return cls(app_id, push_key, user_id)
+
+    def to_line(self):
+        return " ".join((self.app_id, self.push_key, self.user_id))
+
+
+class InvalidateCacheCommand(Command):
+    """Sent by the client to invalidate an upstream cache.
+
+    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
+    NOT DISASTROUS IF WE DROP ON THE FLOOR.
+
+    Mainly used to invalidate destination retry timing caches.
+
+    Format::
+
+        INVALIDATE_CACHE <cache_func> <keys_json>
+
+    Where <keys_json> is a json list.
+    """
+    NAME = "INVALIDATE_CACHE"
+
+    def __init__(self, cache_func, keys):
+        self.cache_func = cache_func
+        self.keys = keys
+
+    @classmethod
+    def from_line(cls, line):
+        cache_func, keys_json = line.split(" ", 1)
+
+        return cls(cache_func, simplejson.loads(keys_json))
+
+    def to_line(self):
+        return " ".join((
+            self.cache_func, _json_encoder.encode(self.keys),
+        ))
+
+
+class UserIpCommand(Command):
+    """Sent periodically when a worker sees activity from a client.
+
+    Format::
+
+        USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
+    """
+    NAME = "USER_IP"
+
+    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        self.user_id = user_id
+        self.access_token = access_token
+        self.ip = ip
+        self.user_agent = user_agent
+        self.device_id = device_id
+        self.last_seen = last_seen
+
+    @classmethod
+    def from_line(cls, line):
+        user_id, jsn = line.split(" ", 1)
+
+        access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
+
+        return cls(
+            user_id, access_token, ip, user_agent, device_id, last_seen
+        )
+
+    def to_line(self):
+        return self.user_id + " " + _json_encoder.encode((
+            self.access_token, self.ip, self.user_agent, self.device_id,
+            self.last_seen,
+        ))
+
+
+# Map of command name to command type.
+COMMAND_MAP = {
+    cmd.NAME: cmd
+    for cmd in (
+        ServerCommand,
+        RdataCommand,
+        PositionCommand,
+        ErrorCommand,
+        PingCommand,
+        NameCommand,
+        ReplicateCommand,
+        UserSyncCommand,
+        FederationAckCommand,
+        SyncCommand,
+        RemovePusherCommand,
+        InvalidateCacheCommand,
+        UserIpCommand,
+    )
+}
+
+# The commands the server is allowed to send
+VALID_SERVER_COMMANDS = (
+    ServerCommand.NAME,
+    RdataCommand.NAME,
+    PositionCommand.NAME,
+    ErrorCommand.NAME,
+    PingCommand.NAME,
+    SyncCommand.NAME,
+)
+
+# The commands the client is allowed to send
+VALID_CLIENT_COMMANDS = (
+    NameCommand.NAME,
+    ReplicateCommand.NAME,
+    PingCommand.NAME,
+    UserSyncCommand.NAME,
+    FederationAckCommand.NAME,
+    RemovePusherCommand.NAME,
+    InvalidateCacheCommand.NAME,
+    UserIpCommand.NAME,
+    ErrorCommand.NAME,
+)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
new file mode 100644
index 0000000000..d7d38464b2
--- /dev/null
+++ b/synapse/replication/tcp/protocol.py
@@ -0,0 +1,655 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This module contains the implementation of both the client and server
+protocols.
+
+The basic structure of the protocol is line based, where the initial word of
+each line specifies the command. The rest of the line is parsed based on the
+command. For example, the `RDATA` command is defined as::
+
+    RDATA <stream_name> <token> <row_json>
+
+(Note that `<row_json>` may contains spaces, but cannot contain newlines.)
+
+Blank lines are ignored.
+
+# Example
+
+An example iteraction is shown below. Each line is prefixed with '>' or '<' to
+indicate which side is sending, these are *not* included on the wire::
+
+    * connection established *
+    > SERVER localhost:8823
+    > PING 1490197665618
+    < NAME synapse.app.appservice
+    < PING 1490197665618
+    < REPLICATE events 1
+    < REPLICATE backfill 1
+    < REPLICATE caches 1
+    > POSITION events 1
+    > POSITION backfill 1
+    > POSITION caches 1
+    > RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
+    > RDATA events 14 ["$149019767112vOHxz:localhost:8823",
+        "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
+    < PING 1490197675618
+    > ERROR server stopping
+    * connection closed by server *
+"""
+
+from twisted.internet import defer
+from twisted.protocols.basic import LineOnlyReceiver
+from twisted.python.failure import Failure
+
+from .commands import (
+    COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
+    ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
+    NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
+)
+from .streams import STREAMS_MAP
+
+from synapse.util.stringutils import random_string
+from synapse.metrics.metric import CounterMetric
+
+import logging
+import synapse.metrics
+import struct
+import fcntl
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+connection_close_counter = metrics.register_counter(
+    "close_reason", labels=["reason_type"],
+)
+
+
+# A list of all connected protocols. This allows us to send metrics about the
+# connections.
+connected_connections = []
+
+
+logger = logging.getLogger(__name__)
+
+
+PING_TIME = 5000
+PING_TIMEOUT_MULTIPLIER = 5
+PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
+
+
+class ConnectionStates(object):
+    CONNECTING = "connecting"
+    ESTABLISHED = "established"
+    PAUSED = "paused"
+    CLOSED = "closed"
+
+
+class BaseReplicationStreamProtocol(LineOnlyReceiver):
+    """Base replication protocol shared between client and server.
+
+    Reads lines (ignoring blank ones) and parses them into command classes,
+    asserting that they are valid for the given direction, i.e. server commands
+    are only sent by the server.
+
+    On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
+    command.
+
+    It also sends `PING` periodically, and correctly times out remote connections
+    (if they send a `PING` command)
+    """
+    delimiter = b'\n'
+
+    VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
+    VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
+
+    max_line_buffer = 10000
+
+    def __init__(self, clock):
+        self.clock = clock
+
+        self.last_received_command = self.clock.time_msec()
+        self.last_sent_command = 0
+        self.time_we_closed = None  # When we requested the connection be closed
+
+        self.received_ping = False  # Have we reecived a ping from the other side
+
+        self.state = ConnectionStates.CONNECTING
+
+        self.name = "anon"  # The name sent by a client.
+        self.conn_id = random_string(5)  # To dedupe in case of name clashes.
+
+        # List of pending commands to send once we've established the connection
+        self.pending_commands = []
+
+        # The LoopingCall for sending pings.
+        self._send_ping_loop = None
+
+        self.inbound_commands_counter = CounterMetric(
+            "inbound_commands", labels=["command"],
+        )
+        self.outbound_commands_counter = CounterMetric(
+            "outbound_commands", labels=["command"],
+        )
+
+    def connectionMade(self):
+        logger.info("[%s] Connection established", self.id())
+
+        self.state = ConnectionStates.ESTABLISHED
+
+        connected_connections.append(self)  # Register connection for metrics
+
+        self.transport.registerProducer(self, True)  # For the *Producing callbacks
+
+        self._send_pending_commands()
+
+        # Starts sending pings
+        self._send_ping_loop = self.clock.looping_call(self.send_ping, 5000)
+
+        # Always send the initial PING so that the other side knows that they
+        # can time us out.
+        self.send_command(PingCommand(self.clock.time_msec()))
+
+    def send_ping(self):
+        """Periodically sends a ping and checks if we should close the connection
+        due to the other side timing out.
+        """
+        now = self.clock.time_msec()
+
+        if self.time_we_closed:
+            if now - self.time_we_closed > PING_TIMEOUT_MS:
+                logger.info(
+                    "[%s] Failed to close connection gracefully, aborting", self.id()
+                )
+                self.transport.abortConnection()
+        else:
+            if now - self.last_sent_command >= PING_TIME:
+                self.send_command(PingCommand(now))
+
+            if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+                logger.info(
+                    "[%s] Connection hasn't received command in %r ms. Closing.",
+                    self.id(), now - self.last_received_command
+                )
+                self.send_error("ping timeout")
+
+    def lineReceived(self, line):
+        """Called when we've received a line
+        """
+        if line.strip() == "":
+            # Ignore blank lines
+            return
+
+        line = line.decode("utf-8")
+        cmd_name, rest_of_line = line.split(" ", 1)
+
+        if cmd_name not in self.VALID_INBOUND_COMMANDS:
+            logger.error("[%s] invalid command %s", self.id(), cmd_name)
+            self.send_error("invalid command: %s", cmd_name)
+            return
+
+        self.last_received_command = self.clock.time_msec()
+
+        self.inbound_commands_counter.inc(cmd_name)
+
+        cmd_cls = COMMAND_MAP[cmd_name]
+        try:
+            cmd = cmd_cls.from_line(rest_of_line)
+        except Exception as e:
+            logger.exception(
+                "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
+            )
+            self.send_error(
+                "failed to parse line for  %r: %r (%r):" % (cmd_name, e, rest_of_line)
+            )
+            return
+
+        # Now lets try and call on_<CMD_NAME> function
+        try:
+            getattr(self, "on_%s" % (cmd_name,))(cmd)
+        except Exception:
+            logger.exception("[%s] Failed to handle line: %r", self.id(), line)
+
+    def close(self):
+        logger.warn("[%s] Closing connection", self.id())
+        self.time_we_closed = self.clock.time_msec()
+        self.transport.loseConnection()
+        self.on_connection_closed()
+
+    def send_error(self, error_string, *args):
+        """Send an error to remote and close the connection.
+        """
+        self.send_command(ErrorCommand(error_string % args))
+        self.close()
+
+    def send_command(self, cmd, do_buffer=True):
+        """Send a command if connection has been established.
+
+        Args:
+            cmd (Command)
+            do_buffer (bool): Whether to buffer the message or always attempt
+                to send the command. This is mostly used to send an error
+                message if we're about to close the connection due our buffers
+                becoming full.
+        """
+        if self.state == ConnectionStates.CLOSED:
+            logger.debug("[%s] Not sending, connection closed", self.id())
+            return
+
+        if do_buffer and self.state != ConnectionStates.ESTABLISHED:
+            self._queue_command(cmd)
+            return
+
+        self.outbound_commands_counter.inc(cmd.NAME)
+
+        string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+        if "\n" in string:
+            raise Exception("Unexpected newline in command: %r", string)
+
+        self.sendLine(string.encode("utf-8"))
+
+        self.last_sent_command = self.clock.time_msec()
+
+    def _queue_command(self, cmd):
+        """Queue the command until the connection is ready to write to again.
+        """
+        logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+        self.pending_commands.append(cmd)
+
+        if len(self.pending_commands) > self.max_line_buffer:
+            # The other side is failing to keep up and out buffers are becoming
+            # full, so lets close the connection.
+            # XXX: should we squawk more loudly?
+            logger.error("[%s] Remote failed to keep up", self.id())
+            self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
+            self.close()
+
+    def _send_pending_commands(self):
+        """Send any queued commandes
+        """
+        pending = self.pending_commands
+        self.pending_commands = []
+        for cmd in pending:
+            self.send_command(cmd)
+
+    def on_PING(self, line):
+        self.received_ping = True
+
+    def on_ERROR(self, cmd):
+        logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
+
+    def pauseProducing(self):
+        """This is called when both the kernel send buffer and the twisted
+        tcp connection send buffers have become full.
+
+        We don't actually have any control over those sizes, so we buffer some
+        commands ourselves before knifing the connection due to the remote
+        failing to keep up.
+        """
+        logger.info("[%s] Pause producing", self.id())
+        self.state = ConnectionStates.PAUSED
+
+    def resumeProducing(self):
+        """The remote has caught up after we started buffering!
+        """
+        logger.info("[%s] Resume producing", self.id())
+        self.state = ConnectionStates.ESTABLISHED
+        self._send_pending_commands()
+
+    def stopProducing(self):
+        """We're never going to send any more data (normally because either
+        we or the remote has closed the connection)
+        """
+        logger.info("[%s] Stop producing", self.id())
+        self.on_connection_closed()
+
+    def connectionLost(self, reason):
+        logger.info("[%s] Replication connection closed: %r", self.id(), reason)
+        if isinstance(reason, Failure):
+            connection_close_counter.inc(reason.type.__name__)
+        else:
+            connection_close_counter.inc(reason.__class__.__name__)
+
+        try:
+            # Remove us from list of connections to be monitored
+            connected_connections.remove(self)
+        except ValueError:
+            pass
+
+        # Stop the looping call sending pings.
+        if self._send_ping_loop and self._send_ping_loop.running:
+            self._send_ping_loop.stop()
+
+        self.on_connection_closed()
+
+    def on_connection_closed(self):
+        logger.info("[%s] Connection was closed", self.id())
+
+        self.state = ConnectionStates.CLOSED
+        self.pending_commands = []
+
+        if self.transport:
+            self.transport.unregisterProducer()
+
+    def __str__(self):
+        return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
+            self.name, self.conn_id, self.addr,
+        )
+
+    def id(self):
+        return "%s-%s" % (self.name, self.conn_id)
+
+
+class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
+    VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
+    VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
+
+    def __init__(self, server_name, clock, streamer, addr):
+        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+
+        self.server_name = server_name
+        self.streamer = streamer
+        self.addr = addr
+
+        # The streams the client has subscribed to and is up to date with
+        self.replication_streams = set()
+
+        # The streams the client is currently subscribing to.
+        self.connecting_streams = set()
+
+        # Map from stream name to list of updates to send once we've finished
+        # subscribing the client to the stream.
+        self.pending_rdata = {}
+
+    def connectionMade(self):
+        self.send_command(ServerCommand(self.server_name))
+        BaseReplicationStreamProtocol.connectionMade(self)
+        self.streamer.new_connection(self)
+
+    def on_NAME(self, cmd):
+        logger.info("[%s] Renamed to %r", self.id(), cmd.data)
+        self.name = cmd.data
+
+    def on_USER_SYNC(self, cmd):
+        self.streamer.on_user_sync(
+            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+        )
+
+    def on_REPLICATE(self, cmd):
+        stream_name = cmd.stream_name
+        token = cmd.token
+
+        if stream_name == "ALL":
+            # Subscribe to all streams we're publishing to.
+            for stream in self.streamer.streams_by_name.iterkeys():
+                self.subscribe_to_stream(stream, token)
+        else:
+            self.subscribe_to_stream(stream_name, token)
+
+    def on_FEDERATION_ACK(self, cmd):
+        self.streamer.federation_ack(cmd.token)
+
+    def on_REMOVE_PUSHER(self, cmd):
+        self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+
+    def on_INVALIDATE_CACHE(self, cmd):
+        self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+
+    def on_USER_IP(self, cmd):
+        self.streamer.on_user_ip(
+            cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+            cmd.last_seen,
+        )
+
+    @defer.inlineCallbacks
+    def subscribe_to_stream(self, stream_name, token):
+        """Subscribe the remote to a streams.
+
+        This invloves checking if they've missed anything and sending those
+        updates down if they have. During that time new updates for the stream
+        are queued and sent once we've sent down any missed updates.
+        """
+        self.replication_streams.discard(stream_name)
+        self.connecting_streams.add(stream_name)
+
+        try:
+            # Get missing updates
+            updates, current_token = yield self.streamer.get_stream_updates(
+                stream_name, token,
+            )
+
+            # Send all the missing updates
+            for update in updates:
+                token, row = update[0], update[1]
+                self.send_command(RdataCommand(stream_name, token, row))
+
+            # We send a POSITION command to ensure that they have an up to
+            # date token (especially useful if we didn't send any updates
+            # above)
+            self.send_command(PositionCommand(stream_name, current_token))
+
+            # Now we can send any updates that came in while we were subscribing
+            pending_rdata = self.pending_rdata.pop(stream_name, [])
+            for token, update in pending_rdata:
+                # Only send updates newer than the current token
+                if token > current_token:
+                    self.send_command(RdataCommand(stream_name, token, update))
+
+            # They're now fully subscribed
+            self.replication_streams.add(stream_name)
+        except Exception as e:
+            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
+            self.send_error("failed to handle replicate: %r", e)
+        finally:
+            self.connecting_streams.discard(stream_name)
+
+    def stream_update(self, stream_name, token, data):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not
+        """
+        if stream_name in self.replication_streams:
+            # The client is subscribed to the stream
+            self.send_command(RdataCommand(stream_name, token, data))
+        elif stream_name in self.connecting_streams:
+            # The client is being subscribed to the stream
+            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
+            self.pending_rdata.setdefault(stream_name, []).append((token, data))
+        else:
+            # The client isn't subscribed
+            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+
+    def send_sync(self, data):
+        self.send_command(SyncCommand(data))
+
+    def on_connection_closed(self):
+        BaseReplicationStreamProtocol.on_connection_closed(self)
+        self.streamer.lost_connection(self)
+
+
+class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
+    VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
+    VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
+
+    def __init__(self, client_name, server_name, clock, handler):
+        BaseReplicationStreamProtocol.__init__(self, clock)
+
+        self.client_name = client_name
+        self.server_name = server_name
+        self.handler = handler
+
+        # Map of stream to batched updates. See RdataCommand for info on how
+        # batching works.
+        self.pending_batches = {}
+
+    def connectionMade(self):
+        self.send_command(NameCommand(self.client_name))
+        BaseReplicationStreamProtocol.connectionMade(self)
+
+        # Once we've connected subscribe to the necessary streams
+        for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+            self.replicate(stream_name, token)
+
+        # Tell the server if we have any users currently syncing (should only
+        # happen on synchrotrons)
+        currently_syncing = self.handler.get_currently_syncing_users()
+        now = self.clock.time_msec()
+        for user_id in currently_syncing:
+            self.send_command(UserSyncCommand(user_id, True, now))
+
+        # We've now finished connecting to so inform the client handler
+        self.handler.update_connection(self)
+
+    def on_SERVER(self, cmd):
+        if cmd.data != self.server_name:
+            logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
+            self.send_error("Wrong remote")
+
+    def on_RDATA(self, cmd):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.inc(stream_name)
+
+        try:
+            row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
+        except Exception:
+            logger.exception(
+                "[%s] Failed to parse RDATA: %r %r",
+                self.id(), stream_name, cmd.row
+            )
+            raise
+
+        if cmd.token is None:
+            # I.e. this is part of a batch of updates for this stream. Batch
+            # until we get an update for the stream with a non None token
+            self.pending_batches.setdefault(stream_name, []).append(row)
+        else:
+            # Check if this is the last of a batch of updates
+            rows = self.pending_batches.pop(stream_name, [])
+            rows.append(row)
+
+            self.handler.on_rdata(stream_name, cmd.token, rows)
+
+    def on_POSITION(self, cmd):
+        self.handler.on_position(cmd.stream_name, cmd.token)
+
+    def on_SYNC(self, cmd):
+        self.handler.on_sync(cmd.data)
+
+    def replicate(self, stream_name, token):
+        """Send the subscription request to the server
+        """
+        if stream_name not in STREAMS_MAP:
+            raise Exception("Invalid stream name %r" % (stream_name,))
+
+        logger.info(
+            "[%s] Subscribing to replication stream: %r from %r",
+            self.id(), stream_name, token
+        )
+
+        self.send_command(ReplicateCommand(stream_name, token))
+
+    def on_connection_closed(self):
+        BaseReplicationStreamProtocol.on_connection_closed(self)
+        self.handler.update_connection(None)
+
+
+# The following simply registers metrics for the replication connections
+
+metrics.register_callback(
+    "pending_commands",
+    lambda: {
+        (p.name, p.conn_id): len(p.pending_commands)
+        for p in connected_connections
+    },
+    labels=["name", "conn_id"],
+)
+
+
+def transport_buffer_size(protocol):
+    if protocol.transport:
+        size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
+        return size
+    return 0
+
+
+metrics.register_callback(
+    "transport_send_buffer",
+    lambda: {
+        (p.name, p.conn_id): transport_buffer_size(p)
+        for p in connected_connections
+    },
+    labels=["name", "conn_id"],
+)
+
+
+def transport_kernel_read_buffer_size(protocol, read=True):
+    SIOCINQ = 0x541B
+    SIOCOUTQ = 0x5411
+
+    if protocol.transport:
+        fileno = protocol.transport.getHandle().fileno()
+        if read:
+            op = SIOCINQ
+        else:
+            op = SIOCOUTQ
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+        return size
+    return 0
+
+
+metrics.register_callback(
+    "transport_kernel_send_buffer",
+    lambda: {
+        (p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
+        for p in connected_connections
+    },
+    labels=["name", "conn_id"],
+)
+
+
+metrics.register_callback(
+    "transport_kernel_read_buffer",
+    lambda: {
+        (p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
+        for p in connected_connections
+    },
+    labels=["name", "conn_id"],
+)
+
+
+metrics.register_callback(
+    "inbound_commands",
+    lambda: {
+        (k[0], p.name, p.conn_id): count
+        for p in connected_connections
+        for k, count in p.inbound_commands_counter.counts.iteritems()
+    },
+    labels=["command", "name", "conn_id"],
+)
+
+metrics.register_callback(
+    "outbound_commands",
+    lambda: {
+        (k[0], p.name, p.conn_id): count
+        for p in connected_connections
+        for k, count in p.outbound_commands_counter.counts.iteritems()
+    },
+    labels=["command", "name", "conn_id"],
+)
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = metrics.register_counter(
+    "inbound_rdata_count",
+    labels=["stream_name"],
+)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
new file mode 100644
index 0000000000..a41af4fd6c
--- /dev/null
+++ b/synapse/replication/tcp/resource.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The server side of the replication stream.
+"""
+
+from twisted.internet import defer, reactor
+from twisted.internet.protocol import Factory
+
+from .streams import STREAMS_MAP, FederationStream
+from .protocol import ServerReplicationStreamProtocol
+
+from synapse.util.metrics import Measure, measure_func
+
+import logging
+import synapse.metrics
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+stream_updates_counter = metrics.register_counter(
+    "stream_updates", labels=["stream_name"]
+)
+user_sync_counter = metrics.register_counter("user_sync")
+federation_ack_counter = metrics.register_counter("federation_ack")
+remove_pusher_counter = metrics.register_counter("remove_pusher")
+invalidate_cache_counter = metrics.register_counter("invalidate_cache")
+user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationStreamProtocolFactory(Factory):
+    """Factory for new replication connections.
+    """
+    def __init__(self, hs):
+        self.streamer = ReplicationStreamer(hs)
+        self.clock = hs.get_clock()
+        self.server_name = hs.config.server_name
+
+    def buildProtocol(self, addr):
+        return ServerReplicationStreamProtocol(
+            self.server_name,
+            self.clock,
+            self.streamer,
+            addr
+        )
+
+
+class ReplicationStreamer(object):
+    """Handles replication connections.
+
+    This needs to be poked when new replication data may be available. When new
+    data is available it will propagate to all connected clients.
+    """
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
+
+        # Current connections.
+        self.connections = []
+
+        metrics.register_callback("total_connections", lambda: len(self.connections))
+
+        # List of streams that clients can subscribe to.
+        # We only support federation stream if federation sending hase been
+        # disabled on the master.
+        self.streams = [
+            stream(hs) for stream in STREAMS_MAP.itervalues()
+            if stream != FederationStream or not hs.config.send_federation
+        ]
+
+        self.streams_by_name = {stream.NAME: stream for stream in self.streams}
+
+        metrics.register_callback(
+            "connections_per_stream",
+            lambda: {
+                (stream_name,): len([
+                    conn for conn in self.connections
+                    if stream_name in conn.replication_streams
+                ])
+                for stream_name in self.streams_by_name
+            },
+            labels=["stream_name"],
+        )
+
+        self.federation_sender = None
+        if not hs.config.send_federation:
+            self.federation_sender = hs.get_federation_sender()
+
+        self.notifier.add_replication_callback(self.on_notifier_poke)
+
+        # Keeps track of whether we are currently checking for updates
+        self.is_looping = False
+        self.pending_updates = False
+
+        reactor.addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+
+    def on_shutdown(self):
+        # close all connections on shutdown
+        for conn in self.connections:
+            conn.send_error("server shutting down")
+
+    @defer.inlineCallbacks
+    def on_notifier_poke(self):
+        """Checks if there is actually any new data and sends it to the
+        connections if there are.
+
+        This should get called each time new data is available, even if it
+        is currently being executed, so that nothing gets missed
+        """
+        if not self.connections:
+            # Don't bother if nothing is listening. We still need to advance
+            # the stream tokens otherwise they'll fall beihind forever
+            for stream in self.streams:
+                stream.discard_updates_and_advance()
+            return
+
+        # If we're in the process of checking for new updates, mark that fact
+        # and return
+        if self.is_looping:
+            logger.debug("Noitifier poke loop already running")
+            self.pending_updates = True
+            return
+
+        self.pending_updates = True
+        self.is_looping = True
+
+        try:
+            # Keep looping while there have been pokes about potential updates.
+            # This protects against the race where a stream we already checked
+            # gets an update while we're handling other streams.
+            while self.pending_updates:
+                self.pending_updates = False
+
+                with Measure(self.clock, "repl.stream.get_updates"):
+                    # First we tell the streams that they should update their
+                    # current tokens.
+                    for stream in self.streams:
+                        stream.advance_current_token()
+
+                    for stream in self.streams:
+                        if stream.last_token == stream.upto_token:
+                            continue
+
+                        logger.debug(
+                            "Getting stream: %s: %s -> %s",
+                            stream.NAME, stream.last_token, stream.upto_token
+                        )
+                        try:
+                            updates, current_token = yield stream.get_updates()
+                        except Exception:
+                            logger.info("Failed to handle stream %s", stream.NAME)
+                            raise
+
+                        logger.debug(
+                            "Sending %d updates to %d connections",
+                            len(updates), len(self.connections),
+                        )
+
+                        if updates:
+                            logger.info(
+                                "Streaming: %s -> %s", stream.NAME, updates[-1][0]
+                            )
+                            stream_updates_counter.inc_by(len(updates), stream.NAME)
+
+                        # Some streams return multiple rows with the same stream IDs,
+                        # we need to make sure they get sent out in batches. We do
+                        # this by setting the current token to all but the last of
+                        # a series of updates with the same token to have a None
+                        # token. See RdataCommand for more details.
+                        batched_updates = _batch_updates(updates)
+
+                        for conn in self.connections:
+                            for token, row in batched_updates:
+                                try:
+                                    conn.stream_update(stream.NAME, token, row)
+                                except Exception:
+                                    logger.exception("Failed to replicate")
+
+            logger.debug("No more pending updates, breaking poke loop")
+        finally:
+            self.pending_updates = False
+            self.is_looping = False
+
+    @measure_func("repl.get_stream_updates")
+    def get_stream_updates(self, stream_name, token):
+        """For a given stream get all updates since token. This is called when
+        a client first subscribes to a stream.
+        """
+        stream = self.streams_by_name.get(stream_name, None)
+        if not stream:
+            raise Exception("unknown stream %s", stream_name)
+
+        return stream.get_updates_since(token)
+
+    @measure_func("repl.federation_ack")
+    def federation_ack(self, token):
+        """We've received an ack for federation stream from a client.
+        """
+        federation_ack_counter.inc()
+        if self.federation_sender:
+            self.federation_sender.federation_ack(token)
+
+    @measure_func("repl.on_user_sync")
+    @defer.inlineCallbacks
+    def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+        """A client has started/stopped syncing on a worker.
+        """
+        user_sync_counter.inc()
+        yield self.presence_handler.update_external_syncs_row(
+            conn_id, user_id, is_syncing, last_sync_ms,
+        )
+
+    @measure_func("repl.on_remove_pusher")
+    @defer.inlineCallbacks
+    def on_remove_pusher(self, app_id, push_key, user_id):
+        """A client has asked us to remove a pusher
+        """
+        remove_pusher_counter.inc()
+        yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=app_id, pushkey=push_key, user_id=user_id
+        )
+
+        self.notifier.on_new_replication_data()
+
+    @measure_func("repl.on_invalidate_cache")
+    def on_invalidate_cache(self, cache_func, keys):
+        """The client has asked us to invalidate a cache
+        """
+        invalidate_cache_counter.inc()
+        getattr(self.store, cache_func).invalidate(tuple(keys))
+
+    @measure_func("repl.on_user_ip")
+    @defer.inlineCallbacks
+    def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        """The client saw a user request
+        """
+        user_ip_cache_counter.inc()
+        yield self.store.insert_client_ip(
+            user_id, access_token, ip, user_agent, device_id, last_seen,
+        )
+
+    def send_sync_to_all_connections(self, data):
+        """Sends a SYNC command to all clients.
+
+        Used in tests.
+        """
+        for conn in self.connections:
+            conn.send_sync(data)
+
+    def new_connection(self, connection):
+        """A new client connection has been established
+        """
+        self.connections.append(connection)
+
+    def lost_connection(self, connection):
+        """A client connection has been lost
+        """
+        try:
+            self.connections.remove(connection)
+        except ValueError:
+            pass
+
+        # We need to tell the presence handler that the connection has been
+        # lost so that it can handle any ongoing syncs on that connection.
+        self.presence_handler.update_external_syncs_clear(connection.conn_id)
+
+
+def _batch_updates(updates):
+    """Takes a list of updates of form [(token, row)] and sets the token to
+    None for all rows where the next row has the same token. This is used to
+    implement batching.
+
+    For example:
+
+        [(1, _), (1, _), (2, _), (3, _), (3, _)]
+
+    becomes:
+
+        [(None, _), (1, _), (2, _), (None, _), (3, _)]
+    """
+    if not updates:
+        return []
+
+    new_updates = []
+    for i, update in enumerate(updates[:-1]):
+        if update[0] == updates[i + 1][0]:
+            new_updates.append((None, update[1]))
+        else:
+            new_updates.append(update)
+
+    new_updates.append(updates[-1])
+    return new_updates
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
new file mode 100644
index 0000000000..4c60bf79f9
--- /dev/null
+++ b/synapse/replication/tcp/streams.py
@@ -0,0 +1,506 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines all the valid streams that clients can subscribe to, and the format
+of the rows returned by each stream.
+
+Each stream is defined by the following information:
+
+    stream name:        The name of the stream
+    row type:           The type that is used to serialise/deserialse the row
+    current_token:      The function that returns the current token for the stream
+    update_function:    The function that returns a list of updates between two tokens
+"""
+
+from twisted.internet import defer
+from collections import namedtuple
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+MAX_EVENTS_BEHIND = 10000
+
+
+EventStreamRow = namedtuple("EventStreamRow", (
+    "event_id",  # str
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str, optional
+    "redacts",  # str, optional
+))
+BackfillStreamRow = namedtuple("BackfillStreamRow", (
+    "event_id",  # str
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str, optional
+    "redacts",  # str, optional
+))
+PresenceStreamRow = namedtuple("PresenceStreamRow", (
+    "user_id",  # str
+    "state",  # str
+    "last_active_ts",  # int
+    "last_federation_update_ts",  # int
+    "last_user_sync_ts",  # int
+    "status_msg",   # str
+    "currently_active",  # bool
+))
+TypingStreamRow = namedtuple("TypingStreamRow", (
+    "room_id",  # str
+    "user_ids",  # list(str)
+))
+ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
+    "room_id",  # str
+    "receipt_type",  # str
+    "user_id",  # str
+    "event_id",  # str
+    "data",  # dict
+))
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
+    "user_id",  # str
+))
+PushersStreamRow = namedtuple("PushersStreamRow", (
+    "user_id",  # str
+    "app_id",  # str
+    "pushkey",  # str
+    "deleted",  # bool
+))
+CachesStreamRow = namedtuple("CachesStreamRow", (
+    "cache_func",  # str
+    "keys",  # list(str)
+    "invalidation_ts",  # int
+))
+PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
+    "room_id",  # str
+    "visibility",  # str
+    "appservice_id",  # str, optional
+    "network_id",  # str, optional
+))
+DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
+    "user_id",  # str
+    "destination",  # str
+))
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
+    "entity",  # str
+))
+FederationStreamRow = namedtuple("FederationStreamRow", (
+    "type",  # str, the type of data as defined in the BaseFederationRows
+    "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+))
+TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
+    "user_id",  # str
+    "room_id",  # str
+    "data",  # dict
+))
+AccountDataStreamRow = namedtuple("AccountDataStream", (
+    "user_id",  # str
+    "room_id",  # str
+    "data_type",  # str
+    "data",  # dict
+))
+CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str
+    "event_id",  # str, optional
+))
+GroupsStreamRow = namedtuple("GroupsStreamRow", (
+    "group_id",  # str
+    "user_id",  # str
+    "type",  # str
+    "content",  # dict
+))
+
+
+class Stream(object):
+    """Base class for the streams.
+
+    Provides a `get_updates()` function that returns new updates since the last
+    time it was called up until the point `advance_current_token` was called.
+    """
+    NAME = None  # The name of the stream
+    ROW_TYPE = None  # The type of the row
+    _LIMITED = True  # Whether the update function takes a limit
+
+    def __init__(self, hs):
+        # The token from which we last asked for updates
+        self.last_token = self.current_token()
+
+        # The token that we will get updates up to
+        self.upto_token = self.current_token()
+
+    def advance_current_token(self):
+        """Updates `upto_token` to "now", which updates up until which point
+        get_updates[_since] will fetch rows till.
+        """
+        self.upto_token = self.current_token()
+
+    def discard_updates_and_advance(self):
+        """Called when the stream should advance but the updates would be discarded,
+        e.g. when there are no currently connected workers.
+        """
+        self.upto_token = self.current_token()
+        self.last_token = self.upto_token
+
+    @defer.inlineCallbacks
+    def get_updates(self):
+        """Gets all updates since the last time this function was called (or
+        since the stream was constructed if it hadn't been called before),
+        until the `upto_token`
+
+        Returns:
+            (list(ROW_TYPE), int): list of updates plus the token used as an
+                upper bound of the updates (i.e. the "current token")
+        """
+        updates, current_token = yield self.get_updates_since(self.last_token)
+        self.last_token = current_token
+
+        defer.returnValue((updates, current_token))
+
+    @defer.inlineCallbacks
+    def get_updates_since(self, from_token):
+        """Like get_updates except allows specifying from when we should
+        stream updates
+
+        Returns:
+            (list(ROW_TYPE), int): list of updates plus the token used as an
+                upper bound of the updates (i.e. the "current token")
+        """
+        if from_token in ("NOW", "now"):
+            defer.returnValue(([], self.upto_token))
+
+        current_token = self.upto_token
+
+        from_token = int(from_token)
+
+        if from_token == current_token:
+            defer.returnValue(([], current_token))
+
+        if self._LIMITED:
+            rows = yield self.update_function(
+                from_token, current_token,
+                limit=MAX_EVENTS_BEHIND + 1,
+            )
+
+            if len(rows) >= MAX_EVENTS_BEHIND:
+                raise Exception("stream %s has fallen behined" % (self.NAME))
+        else:
+            rows = yield self.update_function(
+                from_token, current_token,
+            )
+
+        updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows]
+
+        defer.returnValue((updates, current_token))
+
+    def current_token(self):
+        """Gets the current token of the underlying streams. Should be provided
+        by the sub classes
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
+    def update_function(self, from_token, current_token, limit=None):
+        """Get updates between from_token and to_token. If Stream._LIMITED is
+        True then limit is provided, otherwise it's not.
+
+        Returns:
+            Deferred(list(tuple)): the first entry in the tuple is the token for
+                that update, and the rest of the tuple gets used to construct
+                a ``ROW_TYPE`` instance
+        """
+        raise NotImplementedError()
+
+
+class EventsStream(Stream):
+    """We received a new event, or an event went from being an outlier to not
+    """
+    NAME = "events"
+    ROW_TYPE = EventStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+        self.current_token = store.get_current_events_token
+        self.update_function = store.get_all_new_forward_event_rows
+
+        super(EventsStream, self).__init__(hs)
+
+
+class BackfillStream(Stream):
+    """We fetched some old events and either we had never seen that event before
+    or it went from being an outlier to not.
+    """
+    NAME = "backfill"
+    ROW_TYPE = BackfillStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+        self.current_token = store.get_current_backfill_token
+        self.update_function = store.get_all_new_backfill_event_rows
+
+        super(BackfillStream, self).__init__(hs)
+
+
+class PresenceStream(Stream):
+    NAME = "presence"
+    _LIMITED = False
+    ROW_TYPE = PresenceStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+        presence_handler = hs.get_presence_handler()
+
+        self.current_token = store.get_current_presence_token
+        self.update_function = presence_handler.get_all_presence_updates
+
+        super(PresenceStream, self).__init__(hs)
+
+
+class TypingStream(Stream):
+    NAME = "typing"
+    _LIMITED = False
+    ROW_TYPE = TypingStreamRow
+
+    def __init__(self, hs):
+        typing_handler = hs.get_typing_handler()
+
+        self.current_token = typing_handler.get_current_token
+        self.update_function = typing_handler.get_all_typing_updates
+
+        super(TypingStream, self).__init__(hs)
+
+
+class ReceiptsStream(Stream):
+    NAME = "receipts"
+    ROW_TYPE = ReceiptsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_max_receipt_stream_id
+        self.update_function = store.get_all_updated_receipts
+
+        super(ReceiptsStream, self).__init__(hs)
+
+
+class PushRulesStream(Stream):
+    """A user has changed their push rules
+    """
+    NAME = "push_rules"
+    ROW_TYPE = PushRulesStreamRow
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+        super(PushRulesStream, self).__init__(hs)
+
+    def current_token(self):
+        push_rules_token, _ = self.store.get_push_rules_stream_token()
+        return push_rules_token
+
+    @defer.inlineCallbacks
+    def update_function(self, from_token, to_token, limit):
+        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+        defer.returnValue([(row[0], row[2]) for row in rows])
+
+
+class PushersStream(Stream):
+    """A user has added/changed/removed a pusher
+    """
+    NAME = "pushers"
+    ROW_TYPE = PushersStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_pushers_stream_token
+        self.update_function = store.get_all_updated_pushers_rows
+
+        super(PushersStream, self).__init__(hs)
+
+
+class CachesStream(Stream):
+    """A cache was invalidated on the master and no other stream would invalidate
+    the cache on the workers
+    """
+    NAME = "caches"
+    ROW_TYPE = CachesStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_cache_stream_token
+        self.update_function = store.get_all_updated_caches
+
+        super(CachesStream, self).__init__(hs)
+
+
+class PublicRoomsStream(Stream):
+    """The public rooms list changed
+    """
+    NAME = "public_rooms"
+    ROW_TYPE = PublicRoomsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_current_public_room_stream_id
+        self.update_function = store.get_all_new_public_rooms
+
+        super(PublicRoomsStream, self).__init__(hs)
+
+
+class DeviceListsStream(Stream):
+    """Someone added/changed/removed a device
+    """
+    NAME = "device_lists"
+    _LIMITED = False
+    ROW_TYPE = DeviceListsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_device_stream_token
+        self.update_function = store.get_all_device_list_changes_for_remotes
+
+        super(DeviceListsStream, self).__init__(hs)
+
+
+class ToDeviceStream(Stream):
+    """New to_device messages for a client
+    """
+    NAME = "to_device"
+    ROW_TYPE = ToDeviceStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_to_device_stream_token
+        self.update_function = store.get_all_new_device_messages
+
+        super(ToDeviceStream, self).__init__(hs)
+
+
+class FederationStream(Stream):
+    """Data to be sent over federation. Only available when master has federation
+    sending disabled.
+    """
+    NAME = "federation"
+    ROW_TYPE = FederationStreamRow
+
+    def __init__(self, hs):
+        federation_sender = hs.get_federation_sender()
+
+        self.current_token = federation_sender.get_current_token
+        self.update_function = federation_sender.get_replication_rows
+
+        super(FederationStream, self).__init__(hs)
+
+
+class TagAccountDataStream(Stream):
+    """Someone added/removed a tag for a room
+    """
+    NAME = "tag_account_data"
+    ROW_TYPE = TagAccountDataStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_max_account_data_stream_id
+        self.update_function = store.get_all_updated_tags
+
+        super(TagAccountDataStream, self).__init__(hs)
+
+
+class AccountDataStream(Stream):
+    """Global or per room account data was changed
+    """
+    NAME = "account_data"
+    ROW_TYPE = AccountDataStreamRow
+
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+
+        self.current_token = self.store.get_max_account_data_stream_id
+
+        super(AccountDataStream, self).__init__(hs)
+
+    @defer.inlineCallbacks
+    def update_function(self, from_token, to_token, limit):
+        global_results, room_results = yield self.store.get_all_updated_account_data(
+            from_token, from_token, to_token, limit
+        )
+
+        results = list(room_results)
+        results.extend(
+            (stream_id, user_id, None, account_data_type, content,)
+            for stream_id, user_id, account_data_type, content in global_results
+        )
+
+        defer.returnValue(results)
+
+
+class CurrentStateDeltaStream(Stream):
+    """Current state for a room was changed
+    """
+    NAME = "current_state_deltas"
+    ROW_TYPE = CurrentStateDeltaStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_max_current_state_delta_stream_id
+        self.update_function = store.get_all_updated_current_state_deltas
+
+        super(CurrentStateDeltaStream, self).__init__(hs)
+
+
+class GroupServerStream(Stream):
+    NAME = "groups"
+    ROW_TYPE = GroupsStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_group_stream_token
+        self.update_function = store.get_all_groups_changes
+
+        super(GroupServerStream, self).__init__(hs)
+
+
+STREAMS_MAP = {
+    stream.NAME: stream
+    for stream in (
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        CurrentStateDeltaStream,
+        GroupServerStream,
+    )
+}
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index f9f5a3e077..16f5a73b95 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -40,6 +40,7 @@ from synapse.rest.client.v2_alpha import (
     register,
     auth,
     receipts,
+    read_marker,
     keys,
     tokenrefresh,
     tags,
@@ -50,6 +51,8 @@ from synapse.rest.client.v2_alpha import (
     devices,
     thirdparty,
     sendtodevice,
+    user_directory,
+    groups,
 )
 
 from synapse.http.server import JsonResource
@@ -88,6 +91,7 @@ class ClientRestResource(JsonResource):
         register.register_servlets(hs, client_resource)
         auth.register_servlets(hs, client_resource)
         receipts.register_servlets(hs, client_resource)
+        read_marker.register_servlets(hs, client_resource)
         keys.register_servlets(hs, client_resource)
         tokenrefresh.register_servlets(hs, client_resource)
         tags.register_servlets(hs, client_resource)
@@ -98,3 +102,5 @@ class ClientRestResource(JsonResource):
         devices.register_servlets(hs, client_resource)
         thirdparty.register_servlets(hs, client_resource)
         sendtodevice.register_servlets(hs, client_resource)
+        user_directory.register_servlets(hs, client_resource)
+        groups.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 29fcd72375..efd5c9873d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,8 +16,9 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, SynapseError
-from synapse.types import UserID
+from synapse.api.constants import Membership
+from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
+from synapse.types import UserID, create_requester
 from synapse.http.servlet import parse_json_object_from_request
 
 from .base import ClientV1RestServlet, client_path_patterns
@@ -112,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
 
 class PurgeHistoryRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns(
-        "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+        "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
     )
 
     def __init__(self, hs):
+        """
+
+        Args:
+            hs (synapse.server.HomeServer)
+        """
         super(PurgeHistoryRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id):
@@ -127,17 +135,114 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        yield self.handlers.message_handler.purge_history(room_id, event_id)
+        body = parse_json_object_from_request(request, allow_empty_body=True)
+
+        delete_local_events = bool(body.get("delete_local_events", False))
+
+        # establish the topological ordering we should keep events from. The
+        # user can provide an event_id in the URL or the request body, or can
+        # provide a timestamp in the request body.
+        if event_id is None:
+            event_id = body.get('purge_up_to_event_id')
+
+        if event_id is not None:
+            event = yield self.store.get_event(event_id)
+
+            if event.room_id != room_id:
+                raise SynapseError(400, "Event is for wrong room.")
+
+            depth = event.depth
+            logger.info(
+                "[purge] purging up to depth %i (event_id %s)",
+                depth, event_id,
+            )
+        elif 'purge_up_to_ts' in body:
+            ts = body['purge_up_to_ts']
+            if not isinstance(ts, int):
+                raise SynapseError(
+                    400, "purge_up_to_ts must be an int",
+                    errcode=Codes.BAD_JSON,
+                )
+
+            stream_ordering = (
+                yield self.store.find_first_stream_ordering_after_ts(ts)
+            )
+
+            room_event_after_stream_ordering = (
+                yield self.store.get_room_event_after_stream_ordering(
+                    room_id, stream_ordering,
+                )
+            )
+            if room_event_after_stream_ordering:
+                (_, depth, _) = room_event_after_stream_ordering
+            else:
+                logger.warn(
+                    "[purge] purging events not possible: No event found "
+                    "(received_ts %i => stream_ordering %i)",
+                    ts, stream_ordering,
+                )
+                raise SynapseError(
+                    404,
+                    "there is no event to be purged",
+                    errcode=Codes.NOT_FOUND,
+                )
+            logger.info(
+                "[purge] purging up to depth %i (received_ts %i => "
+                "stream_ordering %i)",
+                depth, ts, stream_ordering,
+            )
+        else:
+            raise SynapseError(
+                400,
+                "must specify purge_up_to_event_id or purge_up_to_ts",
+                errcode=Codes.BAD_JSON,
+            )
+
+        purge_id = yield self.handlers.message_handler.start_purge_history(
+            room_id, depth,
+            delete_local_events=delete_local_events,
+        )
 
-        defer.returnValue((200, {}))
+        defer.returnValue((200, {
+            "purge_id": purge_id,
+        }))
+
+
+class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns(
+        "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+    )
+
+    def __init__(self, hs):
+        """
+
+        Args:
+            hs (synapse.server.HomeServer)
+        """
+        super(PurgeHistoryStatusRestServlet, self).__init__(hs)
+        self.handlers = hs.get_handlers()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, purge_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        purge_status = self.handlers.message_handler.get_purge_status(purge_id)
+        if purge_status is None:
+            raise NotFoundError("purge id '%s' not found" % purge_id)
+
+        defer.returnValue((200, purge_status.asdict()))
 
 
 class DeactivateAccountRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
 
     def __init__(self, hs):
-        self.store = hs.get_datastore()
         super(DeactivateAccountRestServlet, self).__init__(hs)
+        self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, target_user_id):
@@ -148,18 +253,171 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        # FIXME: Theoretically there is a race here wherein user resets password
-        # using threepid.
-        yield self.store.user_delete_access_tokens(target_user_id)
-        yield self.store.user_delete_threepids(target_user_id)
-        yield self.store.user_set_password_hash(target_user_id, None)
-
+        yield self._deactivate_account_handler.deactivate_account(target_user_id)
         defer.returnValue((200, {}))
 
 
+class ShutdownRoomRestServlet(ClientV1RestServlet):
+    """Shuts down a room by removing all local users from the room and blocking
+    all future invites and joins to the room. Any local aliases will be repointed
+    to a new room created by `new_room_user_id` and kicked users will be auto
+    joined to the new room.
+    """
+    PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
+
+    DEFAULT_MESSAGE = (
+        "Sharing illegal content on this server is not permitted and rooms in"
+        " violation will be blocked."
+    )
+
+    def __init__(self, hs):
+        super(ShutdownRoomRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.handlers = hs.get_handlers()
+        self.state = hs.get_state_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.room_member_handler = hs.get_room_member_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        content = parse_json_object_from_request(request)
+
+        new_room_user_id = content.get("new_room_user_id")
+        if not new_room_user_id:
+            raise SynapseError(400, "Please provide field `new_room_user_id`")
+
+        room_creator_requester = create_requester(new_room_user_id)
+
+        message = content.get("message", self.DEFAULT_MESSAGE)
+        room_name = content.get("room_name", "Content Violation Notification")
+
+        info = yield self.handlers.room_creation_handler.create_room(
+            room_creator_requester,
+            config={
+                "preset": "public_chat",
+                "name": room_name,
+                "power_level_content_override": {
+                    "users_default": -10,
+                },
+            },
+            ratelimit=False,
+        )
+        new_room_id = info["room_id"]
+
+        yield self.event_creation_handler.create_and_send_nonmember_event(
+            room_creator_requester,
+            {
+                "type": "m.room.message",
+                "content": {"body": message, "msgtype": "m.text"},
+                "room_id": new_room_id,
+                "sender": new_room_user_id,
+            },
+            ratelimit=False,
+        )
+
+        requester_user_id = requester.user.to_string()
+
+        logger.info("Shutting down room %r", room_id)
+
+        yield self.store.block_room(room_id, requester_user_id)
+
+        users = yield self.state.get_current_user_in_room(room_id)
+        kicked_users = []
+        for user_id in users:
+            if not self.hs.is_mine_id(user_id):
+                continue
+
+            logger.info("Kicking %r from %r...", user_id, room_id)
+
+            target_requester = create_requester(user_id)
+            yield self.room_member_handler.update_membership(
+                requester=target_requester,
+                target=target_requester.user,
+                room_id=room_id,
+                action=Membership.LEAVE,
+                content={},
+                ratelimit=False
+            )
+
+            yield self.room_member_handler.forget(target_requester.user, room_id)
+
+            yield self.room_member_handler.update_membership(
+                requester=target_requester,
+                target=target_requester.user,
+                room_id=new_room_id,
+                action=Membership.JOIN,
+                content={},
+                ratelimit=False
+            )
+
+            kicked_users.append(user_id)
+
+        aliases_for_room = yield self.store.get_aliases_for_room(room_id)
+
+        yield self.store.update_aliases_for_room(
+            room_id, new_room_id, requester_user_id
+        )
+
+        defer.returnValue((200, {
+            "kicked_users": kicked_users,
+            "local_aliases": aliases_for_room,
+            "new_room_id": new_room_id,
+        }))
+
+
+class QuarantineMediaInRoom(ClientV1RestServlet):
+    """Quarantines all media in a room so that no one can download it via
+    this server.
+    """
+    PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
+
+    def __init__(self, hs):
+        super(QuarantineMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        num_quarantined = yield self.store.quarantine_media_ids_in_room(
+            room_id, requester.user.to_string(),
+        )
+
+        defer.returnValue((200, {"num_quarantined": num_quarantined}))
+
+
+class ListMediaInRoom(ClientV1RestServlet):
+    """Lists all of the media in a given room.
+    """
+    PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+    def __init__(self, hs):
+        super(ListMediaInRoom, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        is_admin = yield self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            raise AuthError(403, "You are not a server admin")
+
+        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+        defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
 class ResetPasswordRestServlet(ClientV1RestServlet):
     """Post request to allow an administrator reset password for a user.
-    This need a user have a administrator access in Synapse.
+    This needs user to have administrator access in Synapse.
         Example:
             http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
             @user:to_reset_password?access_token=admin_access_token
@@ -177,12 +435,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
         super(ResetPasswordRestServlet, self).__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
+        self._set_password_handler = hs.get_set_password_handler()
 
     @defer.inlineCallbacks
     def on_POST(self, request, target_user_id):
         """Post request to allow an administrator reset password for a user.
-        This need a user have a administrator access in Synapse.
+        This needs user to have administrator access in Synapse.
         """
         UserID.from_string(target_user_id)
         requester = yield self.auth.get_user_by_req(request)
@@ -198,7 +456,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
 
         logger.info("new_password: %r", new_password)
 
-        yield self.auth_handler.set_password(
+        yield self._set_password_handler.set_password(
             target_user_id, new_password, requester
         )
         defer.returnValue((200, {}))
@@ -206,7 +464,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
 
 class GetUsersPaginatedRestServlet(ClientV1RestServlet):
     """Get request to get specific number of users from Synapse.
-    This need a user have a administrator access in Synapse.
+    This needs user to have administrator access in Synapse.
         Example:
             http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
             @admin:user?access_token=admin_access_token&start=0&limit=10
@@ -225,7 +483,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, target_user_id):
         """Get request to get specific number of users from Synapse.
-        This need a user have a administrator access in Synapse.
+        This needs user to have administrator access in Synapse.
         """
         target_user = UserID.from_string(target_user_id)
         requester = yield self.auth.get_user_by_req(request)
@@ -258,7 +516,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, target_user_id):
         """Post request to get specific number of users from Synapse..
-        This need a user have a administrator access in Synapse.
+        This needs user to have administrator access in Synapse.
         Example:
             http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
             @admin:user?access_token=admin_access_token
@@ -296,7 +554,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
 class SearchUsersRestServlet(ClientV1RestServlet):
     """Get request to search user table for specific users according to
     search term.
-    This need a user have a administrator access in Synapse.
+    This needs user to have administrator access in Synapse.
         Example:
             http://localhost:8008/_matrix/client/api/v1/admin/search_users/
             @admin:user?access_token=admin_access_token&term=alice
@@ -316,7 +574,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
     def on_GET(self, request, target_user_id):
         """Get request to search user table for specific users according to
         search term.
-        This need a user have a administrator access in Synapse.
+        This needs user to have a administrator access in Synapse.
         """
         target_user = UserID.from_string(target_user_id)
         requester = yield self.auth.get_user_by_req(request)
@@ -347,9 +605,13 @@ class SearchUsersRestServlet(ClientV1RestServlet):
 def register_servlets(hs, http_server):
     WhoisRestServlet(hs).register(http_server)
     PurgeMediaCacheRestServlet(hs).register(http_server)
+    PurgeHistoryStatusRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
     PurgeHistoryRestServlet(hs).register(http_server)
     UsersRestServlet(hs).register(http_server)
     ResetPasswordRestServlet(hs).register(http_server)
     GetUsersPaginatedRestServlet(hs).register(http_server)
     SearchUsersRestServlet(hs).register(http_server)
+    ShutdownRoomRestServlet(hs).register(http_server)
+    QuarantineMediaInRoom(hs).register(http_server)
+    ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
     """A base Synapse REST Servlet for the client version 1 API.
     """
 
+    # This subclass was presumably created to allow the auth for the v1
+    # protocol version to be different, however this behaviour was removed.
+    # it may no longer be necessary
+
     def __init__(self, hs):
         """
         Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
         """
         self.hs = hs
         self.builder_factory = hs.get_event_builder_factory()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8930f1826f..1c3933380f 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -39,6 +39,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(ClientDirectoryServer, self).__init__(hs)
+        self.store = hs.get_datastore()
         self.handlers = hs.get_handlers()
 
     @defer.inlineCallbacks
@@ -70,7 +71,10 @@ class ClientDirectoryServer(ClientV1RestServlet):
         logger.debug("Got servers: %s", servers)
 
         # TODO(erikj): Check types.
-        # TODO(erikj): Check that room exists
+
+        room = yield self.store.get_room(room_id)
+        if room is None:
+            raise SynapseError(400, "Room does not exist")
 
         dir_handler = self.handlers.directory_handler
 
@@ -89,7 +93,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
                 )
             except SynapseError as e:
                 raise e
-            except:
+            except Exception:
                 logger.exception("Failed to create association")
                 raise
         except AuthError:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 72057f1b0c..34df5be4e9 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,12 +19,13 @@ from synapse.api.errors import SynapseError, LoginError, Codes
 from synapse.types import UserID
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.util.msisdn import phone_number_to_msisdn
 
 from .base import ClientV1RestServlet, client_path_patterns
 
 import simplejson as json
 import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
 
 import logging
 from saml2 import BINDING_HTTP_POST
@@ -33,13 +34,57 @@ from saml2.client import Saml2Client
 
 import xml.etree.ElementTree as ET
 
+from twisted.web.client import PartialDownloadError
+
 
 logger = logging.getLogger(__name__)
 
 
+def login_submission_legacy_convert(submission):
+    """
+    If the input login submission is an old style object
+    (ie. with top-level user / medium / address) convert it
+    to a typed object.
+    """
+    if "user" in submission:
+        submission["identifier"] = {
+            "type": "m.id.user",
+            "user": submission["user"],
+        }
+        del submission["user"]
+
+    if "medium" in submission and "address" in submission:
+        submission["identifier"] = {
+            "type": "m.id.thirdparty",
+            "medium": submission["medium"],
+            "address": submission["address"],
+        }
+        del submission["medium"]
+        del submission["address"]
+
+
+def login_id_thirdparty_from_phone(identifier):
+    """
+    Convert a phone login identifier type to a generic threepid identifier
+    Args:
+        identifier(dict): Login identifier dict of type 'm.id.phone'
+
+    Returns: Login identifier dict of type 'm.id.threepid'
+    """
+    if "country" not in identifier or "number" not in identifier:
+        raise SynapseError(400, "Invalid phone-type identifier")
+
+    msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
+
+    return {
+        "type": "m.id.thirdparty",
+        "medium": "msisdn",
+        "address": msisdn,
+    }
+
+
 class LoginRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login$")
-    PASS_TYPE = "m.login.password"
     SAML2_TYPE = "m.login.saml2"
     CAS_TYPE = "m.login.cas"
     TOKEN_TYPE = "m.login.token"
@@ -48,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(LoginRestServlet, self).__init__(hs)
         self.idp_redirect_url = hs.config.saml2_idp_redirect_url
-        self.password_enabled = hs.config.password_enabled
         self.saml2_enabled = hs.config.saml2_enabled
         self.jwt_enabled = hs.config.jwt_enabled
         self.jwt_secret = hs.config.jwt_secret
@@ -75,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
             # fall back to the fallback API if they don't understand one of the
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
-        if self.password_enabled:
-            flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+        flows.extend((
+            {"type": t} for t in self.auth_handler.get_supported_login_types()
+        ))
 
         return (200, {"flows": flows})
 
@@ -87,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         login_submission = parse_json_object_from_request(request)
         try:
-            if login_submission["type"] == LoginRestServlet.PASS_TYPE:
-                if not self.password_enabled:
-                    raise SynapseError(400, "Password login has been disabled.")
-
-                result = yield self.do_password_login(login_submission)
-                defer.returnValue(result)
-            elif self.saml2_enabled and (login_submission["type"] ==
-                                         LoginRestServlet.SAML2_TYPE):
+            if self.saml2_enabled and (login_submission["type"] ==
+                                       LoginRestServlet.SAML2_TYPE):
                 relay_state = ""
                 if "relay_state" in login_submission:
                     relay_state = "&RelayState=" + urllib.quote(
@@ -111,49 +151,102 @@ class LoginRestServlet(ClientV1RestServlet):
                 result = yield self.do_token_login(login_submission)
                 defer.returnValue(result)
             else:
-                raise SynapseError(400, "Bad login type.")
+                result = yield self._do_other_login(login_submission)
+                defer.returnValue(result)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
 
     @defer.inlineCallbacks
-    def do_password_login(self, login_submission):
-        if 'medium' in login_submission and 'address' in login_submission:
-            address = login_submission['address']
-            if login_submission['medium'] == 'email':
+    def _do_other_login(self, login_submission):
+        """Handle non-token/saml/jwt logins
+
+        Args:
+            login_submission:
+
+        Returns:
+            (int, object): HTTP code/response
+        """
+        # Log the request we got, but only certain fields to minimise the chance of
+        # logging someone's password (even if they accidentally put it in the wrong
+        # field)
+        logger.info(
+            "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
+            login_submission.get('identifier'),
+            login_submission.get('medium'),
+            login_submission.get('address'),
+            login_submission.get('user'),
+        )
+        login_submission_legacy_convert(login_submission)
+
+        if "identifier" not in login_submission:
+            raise SynapseError(400, "Missing param: identifier")
+
+        identifier = login_submission["identifier"]
+        if "type" not in identifier:
+            raise SynapseError(400, "Login identifier has no type")
+
+        # convert phone type identifiers to generic threepids
+        if identifier["type"] == "m.id.phone":
+            identifier = login_id_thirdparty_from_phone(identifier)
+
+        # convert threepid identifiers to user IDs
+        if identifier["type"] == "m.id.thirdparty":
+            address = identifier.get('address')
+            medium = identifier.get('medium')
+
+            if medium is None or address is None:
+                raise SynapseError(400, "Invalid thirdparty identifier")
+
+            if medium == 'email':
                 # For emails, transform the address to lowercase.
                 # We store all email addreses as lowercase in the DB.
                 # (See add_threepid in synapse/handlers/auth.py)
                 address = address.lower()
             user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                login_submission['medium'], address
+                medium, address,
             )
             if not user_id:
+                logger.warn(
+                    "unknown 3pid identifier medium %s, address %r",
+                    medium, address,
+                )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-        else:
-            user_id = login_submission['user']
 
-        if not user_id.startswith('@'):
-            user_id = UserID.create(
-                user_id, self.hs.hostname
-            ).to_string()
+            identifier = {
+                "type": "m.id.user",
+                "user": user_id,
+            }
+
+        # by this point, the identifier should be an m.id.user: if it's anything
+        # else, we haven't understood it.
+        if identifier["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+        if "user" not in identifier:
+            raise SynapseError(400, "User identifier is missing 'user' key")
 
         auth_handler = self.auth_handler
-        user_id = yield auth_handler.validate_password_login(
-            user_id=user_id,
-            password=login_submission["password"],
+        canonical_user_id, callback = yield auth_handler.validate_login(
+            identifier["user"],
+            login_submission,
+        )
+
+        device_id = yield self._register_device(
+            canonical_user_id, login_submission,
         )
-        device_id = yield self._register_device(user_id, login_submission)
         access_token = yield auth_handler.get_access_token_for_user_id(
-            user_id, device_id,
-            login_submission.get("initial_device_display_name"),
+            canonical_user_id, device_id,
         )
+
         result = {
-            "user_id": user_id,  # may have changed
+            "user_id": canonical_user_id,
             "access_token": access_token,
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
 
+        if callback is not None:
+            yield callback(result)
+
         defer.returnValue((200, result))
 
     @defer.inlineCallbacks
@@ -166,7 +259,6 @@ class LoginRestServlet(ClientV1RestServlet):
         device_id = yield self._register_device(user_id, login_submission)
         access_token = yield auth_handler.get_access_token_for_user_id(
             user_id, device_id,
-            login_submission.get("initial_device_display_name"),
         )
         result = {
             "user_id": user_id,  # may have changed
@@ -200,7 +292,7 @@ class LoginRestServlet(ClientV1RestServlet):
         if user is None:
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
-        user_id = UserID.create(user, self.hs.hostname).to_string()
+        user_id = UserID(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
         registered_user_id = yield auth_handler.check_user_exists(user_id)
         if registered_user_id:
@@ -209,7 +301,6 @@ class LoginRestServlet(ClientV1RestServlet):
             )
             access_token = yield auth_handler.get_access_token_for_user_id(
                 registered_user_id, device_id,
-                login_submission.get("initial_device_display_name"),
             )
 
             result = {
@@ -341,7 +432,12 @@ class CasTicketServlet(ClientV1RestServlet):
             "ticket": request.args["ticket"],
             "service": self.cas_service_url
         }
-        body = yield http_client.get_raw(uri, args)
+        try:
+            body = yield http_client.get_raw(uri, args)
+        except PartialDownloadError as pde:
+            # Twisted raises this error if the connection is closed,
+            # even if that's being used old-http style to signal end-of-data
+            body = pde.response
         result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
 
@@ -361,7 +457,7 @@ class CasTicketServlet(ClientV1RestServlet):
                 if required_value != actual_value:
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
-        user_id = UserID.create(user, self.hs.hostname).to_string()
+        user_id = UserID(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
         registered_user_id = yield auth_handler.check_user_exists(user_id)
         if not registered_user_id:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..e092158cb7 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
 
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -30,15 +31,33 @@ class LogoutRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutRestServlet, self).__init__(hs)
-        self.store = hs.get_datastore()
+        self._auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        access_token = get_access_token_from_request(request)
-        yield self.store.delete_access_token(access_token)
+        try:
+            requester = yield self.auth.get_user_by_req(request)
+        except AuthError:
+            # this implies the access token has already been deleted.
+            defer.returnValue((401, {
+                "errcode": "M_UNKNOWN_TOKEN",
+                "error": "Access Token unknown or expired"
+            }))
+        else:
+            if requester.device_id is None:
+                # the acccess token wasn't associated with a device.
+                # Just delete the access token
+                access_token = get_access_token_from_request(request)
+                yield self._auth_handler.delete_access_token(access_token)
+            else:
+                yield self._device_handler.delete_device(
+                    requester.user.to_string(), requester.device_id)
+
         defer.returnValue((200, {}))
 
 
@@ -47,8 +66,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(LogoutAllRestServlet, self).__init__(hs)
-        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
+        self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -57,7 +77,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
-        yield self.store.user_delete_access_tokens(user_id)
+
+        # first delete all of the user's devices
+        yield self._device_handler.delete_all_devices_for_user(user_id)
+
+        # .. and then delete any access tokens which weren't associated with
+        # devices.
+        yield self._auth_handler.delete_access_tokens_for_user(user_id)
         defer.returnValue((200, {}))
 
 
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eafdce865e..4a73813c58 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError
 from synapse.types import UserID
+from synapse.handlers.presence import format_user_presence_state
 from synapse.http.servlet import parse_json_object_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 
@@ -33,6 +34,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(PresenceStatusRestServlet, self).__init__(hs)
         self.presence_handler = hs.get_presence_handler()
+        self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
@@ -48,6 +50,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
                 raise AuthError(403, "You are not allowed to see their presence.")
 
         state = yield self.presence_handler.get_state(target_user=user)
+        state = format_user_presence_state(state, self.clock.time_msec())
 
         defer.returnValue((200, state))
 
@@ -75,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
                 raise KeyError()
         except SynapseError as e:
             raise e
-        except:
+        except Exception:
             raise SynapseError(400, "Unable to parse state")
 
         yield self.presence_handler.set_state(user, state)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1a5045c9ec..e4e3611a14 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(ProfileDisplaynameRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.profile_handler = hs.get_profile_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
 
-        displayname = yield self.handlers.profile_handler.get_displayname(
+        displayname = yield self.profile_handler.get_displayname(
             user,
         )
 
@@ -52,10 +52,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 
         try:
             new_name = content["displayname"]
-        except:
+        except Exception:
             defer.returnValue((400, "Unable to parse name"))
 
-        yield self.handlers.profile_handler.set_displayname(
+        yield self.profile_handler.set_displayname(
             user, requester, new_name, is_admin)
 
         defer.returnValue((200, {}))
@@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(ProfileAvatarURLRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.profile_handler = hs.get_profile_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
 
-        avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+        avatar_url = yield self.profile_handler.get_avatar_url(
             user,
         )
 
@@ -94,10 +94,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
         content = parse_json_object_from_request(request)
         try:
             new_name = content["avatar_url"]
-        except:
+        except Exception:
             defer.returnValue((400, "Unable to parse name"))
 
-        yield self.handlers.profile_handler.set_avatar_url(
+        yield self.profile_handler.set_avatar_url(
             user, requester, new_name, is_admin)
 
         defer.returnValue((200, {}))
@@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(ProfileRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.profile_handler = hs.get_profile_handler()
 
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
 
-        displayname = yield self.handlers.profile_handler.get_displayname(
+        displayname = yield self.profile_handler.get_displayname(
             user,
         )
-        avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+        avatar_url = yield self.profile_handler.get_avatar_url(
             user,
         )
 
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9a2ed6ed88..0206e664c1 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -73,6 +73,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(PushersSetRestServlet, self).__init__(hs)
         self.notifier = hs.get_notifier()
+        self.pusher_pool = self.hs.get_pusherpool()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -81,12 +82,10 @@ class PushersSetRestServlet(ClientV1RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        pusher_pool = self.hs.get_pusherpool()
-
         if ('pushkey' in content and 'app_id' in content
                 and 'kind' in content and
                 content['kind'] is None):
-            yield pusher_pool.remove_pusher(
+            yield self.pusher_pool.remove_pusher(
                 content['app_id'], content['pushkey'], user_id=user.to_string()
             )
             defer.returnValue((200, {}))
@@ -109,14 +108,14 @@ class PushersSetRestServlet(ClientV1RestServlet):
             append = content['append']
 
         if not append:
-            yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+            yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
                 app_id=content['app_id'],
                 pushkey=content['pushkey'],
                 not_user_id=user.to_string()
             )
 
         try:
-            yield pusher_pool.add_pusher(
+            yield self.pusher_pool.add_pusher(
                 user_id=user.to_string(),
                 access_token=requester.access_token_id,
                 kind=content['kind'],
@@ -151,7 +150,8 @@ class PushersRemoveRestServlet(RestServlet):
         super(RestServlet, self).__init__()
         self.hs = hs
         self.notifier = hs.get_notifier()
-        self.auth = hs.get_v1auth()
+        self.auth = hs.get_auth()
+        self.pusher_pool = self.hs.get_pusherpool()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -161,10 +161,8 @@ class PushersRemoveRestServlet(RestServlet):
         app_id = parse_string(request, "app_id", required=True)
         pushkey = parse_string(request, "pushkey", required=True)
 
-        pusher_pool = self.hs.get_pusherpool()
-
         try:
-            yield pusher_pool.remove_pusher(
+            yield self.pusher_pool.remove_pusher(
                 app_id=app_id,
                 pushkey=pushkey,
                 user_id=user.to_string(),
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index ecf7e311a9..9b3022e0b0 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -30,6 +30,8 @@ from hashlib import sha1
 import hmac
 import logging
 
+from six import string_types
+
 logger = logging.getLogger(__name__)
 
 
@@ -70,10 +72,15 @@ class RegisterRestServlet(ClientV1RestServlet):
         self.handlers = hs.get_handlers()
 
     def on_GET(self, request):
+
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [
@@ -82,27 +89,34 @@ class RegisterRestServlet(ClientV1RestServlet):
                             LoginType.PASSWORD
                         ]
                     },
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.RECAPTCHA,
                         "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
                     }
-                ]}
-            )
+                ])
         else:
-            return (
-                200,
-                {"flows": [
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if require_email or not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.EMAIL_IDENTITY,
                         "stages": [
                             LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
                         ]
-                    },
+                    }
+                ])
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([
                     {
                         "type": LoginType.PASSWORD
                     }
-                ]}
-            )
+                ])
+        return (200, {"flows": flows})
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -321,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet):
     def _do_shared_secret(self, request, register_json, session):
         yield run_on_reactor()
 
-        if not isinstance(register_json.get("mac", None), basestring):
+        if not isinstance(register_json.get("mac", None), string_types):
             raise SynapseError(400, "Expected mac.")
-        if not isinstance(register_json.get("user", None), basestring):
+        if not isinstance(register_json.get("user", None), string_types):
             raise SynapseError(400, "Expected 'user' key.")
-        if not isinstance(register_json.get("password", None), basestring):
+        if not isinstance(register_json.get("password", None), string_types):
             raise SynapseError(400, "Expected 'password' key.")
 
         if not self.hs.config.registration_shared_secret:
@@ -336,9 +350,9 @@ class RegisterRestServlet(ClientV1RestServlet):
         admin = register_json.get("admin", None)
 
         # Its important to check as we use null bytes as HMAC field separators
-        if "\x00" in user:
+        if b"\x00" in user:
             raise SynapseError(400, "Invalid user")
-        if "\x00" in password:
+        if b"\x00" in password:
             raise SynapseError(400, "Invalid password")
 
         # str() because otherwise hmac complains that 'unicode' does not
@@ -346,20 +360,20 @@ class RegisterRestServlet(ClientV1RestServlet):
         got_mac = str(register_json["mac"])
 
         want_mac = hmac.new(
-            key=self.hs.config.registration_shared_secret,
+            key=self.hs.config.registration_shared_secret.encode(),
             digestmod=sha1,
         )
         want_mac.update(user)
-        want_mac.update("\x00")
+        want_mac.update(b"\x00")
         want_mac.update(password)
-        want_mac.update("\x00")
-        want_mac.update("admin" if admin else "notadmin")
+        want_mac.update(b"\x00")
+        want_mac.update(b"admin" if admin else b"notadmin")
         want_mac = want_mac.hexdigest()
 
         if compare_digest(want_mac, got_mac):
             handler = self.handlers.registration_handler
             user_id, token = yield handler.register(
-                localpart=user,
+                localpart=user.lower(),
                 password=password,
                 admin=bool(admin),
             )
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 728e3df0e3..fcf9c9ab44 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,9 +28,10 @@ from synapse.http.servlet import (
     parse_json_object_from_request, parse_string, parse_integer
 )
 
+from six.moves.urllib import parse as urlparse
+
 import logging
-import urllib
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -82,6 +84,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomStateEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         # /room/$roomid/state/$eventtype
@@ -154,7 +158,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
 
         if event_type == EventTypes.Member:
             membership = content.get("membership", None)
-            event = yield self.handlers.room_member_handler.update_membership(
+            event = yield self.room_member_handler.update_membership(
                 requester,
                 target=UserID.from_string(state_key),
                 room_id=room_id,
@@ -162,15 +166,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
                 content=content,
             )
         else:
-            msg_handler = self.handlers.message_handler
-            event, context = yield msg_handler.create_event(
+            event = yield self.event_creation_hander.create_and_send_nonmember_event(
+                requester,
                 event_dict,
-                token_id=requester.access_token_id,
                 txn_id=txn_id,
             )
 
-            yield msg_handler.send_nonmember_event(requester, event, context)
-
         ret = {}
         if event:
             ret = {"event_id": event.event_id}
@@ -182,7 +183,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(RoomSendEventRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.event_creation_hander = hs.get_event_creation_handler()
 
     def register(self, http_server):
         # /rooms/$roomid/send/$event_type[/$txn_id]
@@ -194,15 +195,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event_dict = {
+            "type": event_type,
+            "content": content,
+            "room_id": room_id,
+            "sender": requester.user.to_string(),
+        }
+
+        if 'ts' in request.args and requester.app_service:
+            event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+        event = yield self.event_creation_hander.create_and_send_nonmember_event(
             requester,
-            {
-                "type": event_type,
-                "content": content,
-                "room_id": room_id,
-                "sender": requester.user.to_string(),
-            },
+            event_dict,
             txn_id=txn_id,
         )
 
@@ -221,7 +226,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 class JoinRoomAliasServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(JoinRoomAliasServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         # /join/$room_identifier[/$txn_id]
@@ -237,7 +242,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
 
         try:
             content = parse_json_object_from_request(request)
-        except:
+        except Exception:
             # Turns out we used to ignore the body entirely, and some clients
             # cheekily send invalid bodies.
             content = {}
@@ -246,10 +251,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
             room_id = room_identifier
             try:
                 remote_room_hosts = request.args["server_name"]
-            except:
+            except Exception:
                 remote_room_hosts = None
         elif RoomAlias.is_valid(room_identifier):
-            handler = self.handlers.room_member_handler
+            handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
             room_id = room_id.to_string()
@@ -258,7 +263,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
                 room_identifier,
             ))
 
-        yield self.handlers.room_member_handler.update_membership(
+        yield self.room_member_handler.update_membership(
             requester=requester,
             target=requester.user,
             room_id=room_id,
@@ -397,16 +402,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(JoinedRoomMemberListRestServlet, self).__init__(hs)
-        self.state = hs.get_state_handler()
+        self.message_handler = hs.get_handlers().message_handler
 
     @defer.inlineCallbacks
     def on_GET(self, request, room_id):
-        yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(request)
 
-        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+        users_with_profile = yield self.message_handler.get_joined_members(
+            requester, room_id,
+        )
 
         defer.returnValue((200, {
-            "joined": users_with_profile
+            "joined": users_with_profile,
         }))
 
 
@@ -427,7 +434,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
         as_client_event = "raw" not in request.args
         filter_bytes = request.args.get("filter", None)
         if filter_bytes:
-            filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+            filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
             event_filter = Filter(json.loads(filter_json))
         else:
             event_filter = None
@@ -484,13 +491,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
         defer.returnValue((200, content))
 
 
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns(
+        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(RoomEventServlet, self).__init__(hs)
+        self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, event_id):
+        requester = yield self.auth.get_user_by_req(request)
+        event = yield self.event_handler.get_event(requester.user, event_id)
+
+        time_now = self.clock.time_msec()
+        if event:
+            defer.returnValue((200, serialize_event(event, time_now)))
+        else:
+            defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns(
         "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
     )
 
     def __init__(self, hs):
-        super(RoomEventContext, self).__init__(hs)
+        super(RoomEventContextServlet, self).__init__(hs)
         self.clock = hs.get_clock()
         self.handlers = hs.get_handlers()
 
@@ -505,7 +534,6 @@ class RoomEventContext(ClientV1RestServlet):
             room_id,
             event_id,
             limit,
-            requester.is_guest,
         )
 
         if not results:
@@ -531,7 +559,7 @@ class RoomEventContext(ClientV1RestServlet):
 class RoomForgetRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomForgetRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -544,7 +572,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
             allow_guest=False,
         )
 
-        yield self.handlers.room_member_handler.forget(
+        yield self.room_member_handler.forget(
             user=requester.user,
             room_id=room_id,
         )
@@ -562,12 +590,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 
     def __init__(self, hs):
         super(RoomMembershipRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
+        self.room_member_handler = hs.get_room_member_handler()
 
     def register(self, http_server):
         # /rooms/$roomid/[invite|join|leave]
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
-                    "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+                    "(?P<membership_action>join|invite|leave|ban|unban|kick)")
         register_txn_path(self, PATTERNS, http_server)
 
     @defer.inlineCallbacks
@@ -585,13 +613,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
 
         try:
             content = parse_json_object_from_request(request)
-        except:
+        except Exception:
             # Turns out we used to ignore the body entirely, and some clients
             # cheekily send invalid bodies.
             content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            yield self.handlers.room_member_handler.do_3pid_invite(
+            yield self.room_member_handler.do_3pid_invite(
                 room_id,
                 requester.user,
                 content["medium"],
@@ -613,7 +641,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
         if 'reason' in content and membership_action in ['kick', 'ban']:
             event_content = {'reason': content['reason']}
 
-        yield self.handlers.room_member_handler.update_membership(
+        yield self.room_member_handler.update_membership(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -623,7 +651,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
             content=event_content,
         )
 
-        defer.returnValue((200, {}))
+        return_value = {}
+
+        if membership_action == "join":
+            return_value["room_id"] = room_id
+
+        defer.returnValue((200, return_value))
 
     def _has_3pid_invite_keys(self, content):
         for key in {"id_server", "medium", "address"}:
@@ -641,6 +674,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(RoomRedactEventRestServlet, self).__init__(hs)
         self.handlers = hs.get_handlers()
+        self.event_creation_handler = hs.get_event_creation_handler()
 
     def register(self, http_server):
         PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -651,8 +685,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_nonmember_event(
+        event = yield self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,
@@ -686,8 +719,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
     def on_PUT(self, request, room_id, user_id):
         requester = yield self.auth.get_user_by_req(request)
 
-        room_id = urllib.unquote(room_id)
-        target_user = UserID.from_string(urllib.unquote(user_id))
+        room_id = urlparse.unquote(room_id)
+        target_user = UserID.from_string(urlparse.unquote(user_id))
 
         content = parse_json_object_from_request(request)
 
@@ -749,8 +782,7 @@ class JoinedRoomsRestServlet(ClientV1RestServlet):
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
 
-        rooms = yield self.store.get_rooms_for_user(requester.user.to_string())
-        room_ids = set(r.room_id for r in rooms)  # Ensure they're unique.
+        room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
         defer.returnValue((200, {"joined_rooms": list(room_ids)}))
 
 
@@ -802,4 +834,5 @@ def register_servlets(hs, http_server):
     RoomTypingRestServlet(hs).register(http_server)
     SearchRestServlet(hs).register(http_server)
     JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventContext(hs).register(http_server)
+    RoomEventServlet(hs).register(http_server)
+    RoomEventContextServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 03141c623c..c43b30b73a 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield self.auth.get_user_by_req(
+            request,
+            self.hs.config.turn_allow_guests
+        )
 
         turnUris = self.hs.config.turn_uris
         turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 20e765f48f..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
 
 """This module contains base REST classes for constructing client v1 servlets.
 """
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
 import re
 
-import logging
+from twisted.internet import defer
 
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
 
 logger = logging.getLogger(__name__)
 
@@ -47,3 +48,47 @@ def client_v2_patterns(path_regex, releases=(0,),
         new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
         patterns.append(re.compile("^" + new_prefix + path_regex))
     return patterns
+
+
+def set_timeline_upper_limit(filter_json, filter_timeline_limit):
+    if filter_timeline_limit < 0:
+        return  # no upper limits
+    timeline = filter_json.get('room', {}).get('timeline', {})
+    if 'limit' in timeline:
+        filter_json['room']['timeline']["limit"] = min(
+            filter_json['room']['timeline']['limit'],
+            filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+    """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+    Takes a on_POST method which returns a deferred (errcode, body) response
+    and adds exception handling to turn a InteractiveAuthIncompleteError into
+    a 401 response.
+
+    Normal usage is:
+
+    @interactive_auth_handler
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        # ...
+        yield self.auth_handler.check_auth
+            """
+    def wrapped(*args, **kwargs):
+        res = defer.maybeDeferred(orig, *args, **kwargs)
+        res.addErrback(_catch_incomplete_interactive_auth)
+        return res
+    return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+    """helper for interactive_auth_handler
+
+    Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+    Args:
+        f (failure.Failure):
+    """
+    f.trap(InteractiveAuthIncompleteError)
+    return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 398e7f5eb0..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,27 +13,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 from twisted.internet import defer
 
+from synapse.api.auth import has_access_token
 from synapse.api.constants import LoginType
-from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+    RestServlet, assert_params_in_request,
+    parse_json_object_from_request,
+)
 from synapse.util.async import run_on_reactor
-
-from ._base import client_v2_patterns
-
-import logging
-
+from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
 
-class PasswordRequestTokenRestServlet(RestServlet):
+class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
 
     def __init__(self, hs):
-        super(PasswordRequestTokenRestServlet, self).__init__()
+        super(EmailPasswordRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
 
@@ -40,14 +44,14 @@ class PasswordRequestTokenRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
+        assert_params_in_request(body, [
+            'id_server', 'client_secret', 'email', 'send_attempt'
+        ])
 
-        if absent:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
 
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
@@ -60,6 +64,42 @@ class PasswordRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))
 
 
+class MsisdnPasswordRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.datastore = self.hs.get_datastore()
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_request(body, [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ])
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is None:
+            raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
+
+
 class PasswordRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/password$")
 
@@ -68,55 +108,62 @@ class PasswordRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()
+        self._set_password_handler = hs.get_set_password_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
-        yield run_on_reactor()
-
         body = parse_json_object_from_request(request)
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-            [LoginType.EMAIL_IDENTITY]
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        user_id = None
-        requester = None
-
-        if LoginType.PASSWORD in result:
-            # if using password, they should also be logged in
+        # there are two possibilities here. Either the user does not have an
+        # access token, and needs to do a password reset; or they have one and
+        # need to validate their identity.
+        #
+        # In the first case, we offer a couple of means of identifying
+        # themselves (email and msisdn, though it's unclear if msisdn actually
+        # works).
+        #
+        # In the second case, we require a password to confirm their identity.
+
+        if has_access_token(request):
             requester = yield self.auth.get_user_by_req(request)
-            user_id = requester.user.to_string()
-            if user_id != result[LoginType.PASSWORD]:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        elif LoginType.EMAIL_IDENTITY in result:
-            threepid = result[LoginType.EMAIL_IDENTITY]
-            if 'medium' not in threepid or 'address' not in threepid:
-                raise SynapseError(500, "Malformed threepid")
-            if threepid['medium'] == 'email':
-                # For emails, transform the address to lowercase.
-                # We store all email addreses as lowercase in the DB.
-                # (See add_threepid in synapse/handlers/auth.py)
-                threepid['address'] = threepid['address'].lower()
-            # if using email, we must know about the email they're authing with!
-            threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                threepid['medium'], threepid['address']
+            params = yield self.auth_handler.validate_user_via_ui_auth(
+                requester, body, self.hs.get_ip_from_request(request),
             )
-            if not threepid_user_id:
-                raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
-            user_id = threepid_user_id
+            user_id = requester.user.to_string()
         else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
+            requester = None
+            result, params, _ = yield self.auth_handler.check_auth(
+                [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+                body, self.hs.get_ip_from_request(request),
+            )
+
+            if LoginType.EMAIL_IDENTITY in result:
+                threepid = result[LoginType.EMAIL_IDENTITY]
+                if 'medium' not in threepid or 'address' not in threepid:
+                    raise SynapseError(500, "Malformed threepid")
+                if threepid['medium'] == 'email':
+                    # For emails, transform the address to lowercase.
+                    # We store all email addreses as lowercase in the DB.
+                    # (See add_threepid in synapse/handlers/auth.py)
+                    threepid['address'] = threepid['address'].lower()
+                # if using email, we must know about the email they're authing with!
+                threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+                    threepid['medium'], threepid['address']
+                )
+                if not threepid_user_id:
+                    raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+                user_id = threepid_user_id
+            else:
+                logger.error("Auth succeeded but no known type!", result.keys())
+                raise SynapseError(500, "", Codes.UNKNOWN)
 
         if 'new_password' not in params:
             raise SynapseError(400, "", Codes.MISSING_PARAM)
         new_password = params['new_password']
 
-        yield self.auth_handler.set_password(
+        yield self._set_password_handler.set_password(
             user_id, new_password, requester
         )
 
@@ -130,52 +177,43 @@ class DeactivateAccountRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/deactivate$")
 
     def __init__(self, hs):
+        super(DeactivateAccountRestServlet, self).__init__()
         self.hs = hs
-        self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
-        super(DeactivateAccountRestServlet, self).__init__()
+        self._deactivate_account_handler = hs.get_deactivate_account_handler()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
-
-        user_id = None
-        requester = None
-
-        if LoginType.PASSWORD in result:
-            # if using password, they should also be logged in
-            requester = yield self.auth.get_user_by_req(request)
-            user_id = requester.user.to_string()
-            if user_id != result[LoginType.PASSWORD]:
-                raise LoginError(400, "", Codes.UNKNOWN)
-        else:
-            logger.error("Auth succeeded but no known type!", result.keys())
-            raise SynapseError(500, "", Codes.UNKNOWN)
+        requester = yield self.auth.get_user_by_req(request)
 
-        # FIXME: Theoretically there is a race here wherein user resets password
-        # using threepid.
-        yield self.store.user_delete_access_tokens(user_id)
-        yield self.store.user_delete_threepids(user_id)
-        yield self.store.user_set_password_hash(user_id, None)
+        # allow ASes to dectivate their own users
+        if requester.app_service:
+            yield self._deactivate_account_handler.deactivate_account(
+                requester.user.to_string()
+            )
+            defer.returnValue((200, {}))
 
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
+        yield self._deactivate_account_handler.deactivate_account(
+            requester.user.to_string(),
+        )
         defer.returnValue((200, {}))
 
 
-class ThreepidRequestTokenRestServlet(RestServlet):
+class EmailThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
 
     def __init__(self, hs):
         self.hs = hs
-        super(ThreepidRequestTokenRestServlet, self).__init__()
+        super(EmailThreepidRequestTokenRestServlet, self).__init__()
         self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()
 
     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -190,7 +228,12 @@ class ThreepidRequestTokenRestServlet(RestServlet):
         if absent:
             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
 
-        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
             'email', body['email']
         )
 
@@ -201,6 +244,49 @@ class ThreepidRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))
 
 
+class MsisdnThreepidRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        self.hs = hs
+        super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+        self.identity_handler = hs.get_handlers().identity_handler
+        self.datastore = self.hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        required = [
+            'id_server', 'client_secret',
+            'country', 'phone_number', 'send_attempt',
+        ]
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if absent:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
+        existingUid = yield self.datastore.get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
+
+
 class ThreepidRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/account/3pid$")
 
@@ -210,6 +296,7 @@ class ThreepidRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -217,7 +304,7 @@ class ThreepidRestServlet(RestServlet):
 
         requester = yield self.auth.get_user_by_req(request)
 
-        threepids = yield self.hs.get_datastore().user_get_threepids(
+        threepids = yield self.datastore.user_get_threepids(
             requester.user.to_string()
         )
 
@@ -258,7 +345,7 @@ class ThreepidRestServlet(RestServlet):
 
         if 'bind' in body and body['bind']:
             logger.debug(
-                "Binding emails %s to %s",
+                "Binding threepid %s to %s",
                 threepid, user_id
             )
             yield self.identity_handler.bind_threepid(
@@ -301,10 +388,27 @@ class ThreepidDeleteRestServlet(RestServlet):
         defer.returnValue((200, {}))
 
 
+class WhoamiRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/account/whoami$")
+
+    def __init__(self, hs):
+        super(WhoamiRestServlet, self).__init__()
+        self.auth = hs.get_auth()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+
+        defer.returnValue((200, {'user_id': requester.user.to_string()}))
+
+
 def register_servlets(hs, http_server):
-    PasswordRequestTokenRestServlet(hs).register(http_server)
+    EmailPasswordRequestTokenRestServlet(hs).register(http_server)
+    MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
     PasswordRestServlet(hs).register(http_server)
     DeactivateAccountRestServlet(hs).register(http_server)
-    ThreepidRequestTokenRestServlet(hs).register(http_server)
+    EmailThreepidRequestTokenRestServlet(hs).register(http_server)
+    MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
     ThreepidRestServlet(hs).register(http_server)
     ThreepidDeleteRestServlet(hs).register(http_server)
+    WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index b16079cece..0e0a187efd 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -16,7 +16,7 @@
 from ._base import client_v2_patterns
 
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
 
 from twisted.internet import defer
 
@@ -82,6 +82,13 @@ class RoomAccountDataServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
+        if account_data_type == "m.fully_read":
+            raise SynapseError(
+                405,
+                "Cannot set m.fully_read through this API."
+                " Use /rooms/!roomId:server.name/read_markers"
+            )
+
         max_id = yield self.store.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index a1feaf3d54..35d58b367a 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,15 +17,15 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api import constants, errors
+from synapse.api import errors
 from synapse.http import servlet
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 logger = logging.getLogger(__name__)
 
 
 class DevicesRestServlet(servlet.RestServlet):
-    PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+    PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
 
     def __init__(self, hs):
         """
@@ -46,9 +46,53 @@ class DevicesRestServlet(servlet.RestServlet):
         defer.returnValue((200, {"devices": devices}))
 
 
+class DeleteDevicesRestServlet(servlet.RestServlet):
+    """
+    API for bulk deletion of devices. Accepts a JSON object with a devices
+    key which lists the device_ids to delete. Requires user interactive auth.
+    """
+    PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
+
+    def __init__(self, hs):
+        super(DeleteDevicesRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()
+        self.auth_handler = hs.get_auth_handler()
+
+    @interactive_auth_handler
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+
+        try:
+            body = servlet.parse_json_object_from_request(request)
+        except errors.SynapseError as e:
+            if e.errcode == errors.Codes.NOT_JSON:
+                # deal with older clients which didn't pass a J*DELETESON dict
+                # the same as those that pass an empty dict
+                body = {}
+            else:
+                raise e
+
+        if 'devices' not in body:
+            raise errors.SynapseError(
+                400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
+            )
+
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
+
+        yield self.device_handler.delete_devices(
+            requester.user.to_string(),
+            body['devices'],
+        )
+        defer.returnValue((200, {}))
+
+
 class DeviceRestServlet(servlet.RestServlet):
-    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
-                                  releases=[], v2_alpha=False)
+    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
 
     def __init__(self, hs):
         """
@@ -70,8 +114,11 @@ class DeviceRestServlet(servlet.RestServlet):
         )
         defer.returnValue((200, device))
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_DELETE(self, request, device_id):
+        requester = yield self.auth.get_user_by_req(request)
+
         try:
             body = servlet.parse_json_object_from_request(request)
 
@@ -83,17 +130,12 @@ class DeviceRestServlet(servlet.RestServlet):
             else:
                 raise
 
-        authed, result, params, _ = yield self.auth_handler.check_auth([
-            [constants.LoginType.PASSWORD],
-        ], body, self.hs.get_ip_from_request(request))
-
-        if not authed:
-            defer.returnValue((401, result))
+        yield self.auth_handler.validate_user_via_ui_auth(
+            requester, body, self.hs.get_ip_from_request(request),
+        )
 
-        requester = yield self.auth.get_user_by_req(request)
         yield self.device_handler.delete_device(
-            requester.user.to_string(),
-            device_id,
+            requester.user.to_string(), device_id,
         )
         defer.returnValue((200, {}))
 
@@ -111,5 +153,6 @@ class DeviceRestServlet(servlet.RestServlet):
 
 
 def register_servlets(hs, http_server):
+    DeleteDevicesRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
     DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b4084fec62..1b9dc4528d 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -20,6 +20,7 @@ from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 
 from ._base import client_v2_patterns
+from ._base import set_timeline_upper_limit
 
 import logging
 
@@ -49,7 +50,7 @@ class GetFilterRestServlet(RestServlet):
 
         try:
             filter_id = int(filter_id)
-        except:
+        except Exception:
             raise SynapseError(400, "Invalid filter_id")
 
         try:
@@ -85,6 +86,11 @@ class CreateFilterRestServlet(RestServlet):
             raise AuthError(403, "Can only create filters for local users")
 
         content = parse_json_object_from_request(request)
+        set_timeline_upper_limit(
+            content,
+            self.hs.config.filter_timeline_limit
+        )
+
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=target_user.localpart,
             user_filter=content,
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
new file mode 100644
index 0000000000..3bb1ec2af6
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import GroupID
+
+from ._base import client_v2_patterns
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GroupServlet(RestServlet):
+    """Get the group profile
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+
+    def __init__(self, hs):
+        super(GroupServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        group_description = yield self.groups_handler.get_group_profile(
+            group_id,
+            requester_user_id,
+        )
+
+        defer.returnValue((200, group_description))
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        yield self.groups_handler.update_group_profile(
+            group_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, {}))
+
+
+class GroupSummaryServlet(RestServlet):
+    """Get the full group summary
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+
+    def __init__(self, hs):
+        super(GroupSummaryServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        get_group_summary = yield self.groups_handler.get_group_summary(
+            group_id,
+            requester_user_id,
+        )
+
+        defer.returnValue((200, get_group_summary))
+
+
+class GroupSummaryRoomsCatServlet(RestServlet):
+    """Update/delete a rooms entry in the summary.
+
+    Matches both:
+        - /groups/:group/summary/rooms/:room_id
+        - /groups/:group/summary/categories/:category/rooms/:room_id
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/categories/(?P<category_id>[^/]+))?"
+        "/rooms/(?P<room_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSummaryRoomsCatServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, category_id, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        resp = yield self.groups_handler.update_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, group_id, category_id, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        resp = yield self.groups_handler.delete_group_summary_room(
+            group_id, requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class GroupCategoryServlet(RestServlet):
+    """Get/add/update/delete a group category
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupCategoryServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id, category_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        category = yield self.groups_handler.get_group_category(
+            group_id, requester_user_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue((200, category))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, category_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        resp = yield self.groups_handler.update_group_category(
+            group_id, requester_user_id,
+            category_id=category_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, group_id, category_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        resp = yield self.groups_handler.delete_group_category(
+            group_id, requester_user_id,
+            category_id=category_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class GroupCategoriesServlet(RestServlet):
+    """Get all group categories
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/categories/$"
+    )
+
+    def __init__(self, hs):
+        super(GroupCategoriesServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        category = yield self.groups_handler.get_group_categories(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, category))
+
+
+class GroupRoleServlet(RestServlet):
+    """Get/add/update/delete a group role
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupRoleServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id, role_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        category = yield self.groups_handler.get_group_role(
+            group_id, requester_user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue((200, category))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, role_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        resp = yield self.groups_handler.update_group_role(
+            group_id, requester_user_id,
+            role_id=role_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, group_id, role_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        resp = yield self.groups_handler.delete_group_role(
+            group_id, requester_user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class GroupRolesServlet(RestServlet):
+    """Get all group roles
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/roles/$"
+    )
+
+    def __init__(self, hs):
+        super(GroupRolesServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        category = yield self.groups_handler.get_group_roles(
+            group_id, requester_user_id,
+        )
+
+        defer.returnValue((200, category))
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+    """Update/delete a user's entry in the summary.
+
+    Matches both:
+        - /groups/:group/summary/users/:room_id
+        - /groups/:group/summary/roles/:role/users/:user_id
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/roles/(?P<role_id>[^/]+))?"
+        "/users/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSummaryUsersRoleServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, role_id, user_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        resp = yield self.groups_handler.update_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+            content=content,
+        )
+
+        defer.returnValue((200, resp))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, group_id, role_id, user_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        resp = yield self.groups_handler.delete_group_summary_user(
+            group_id, requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+        )
+
+        defer.returnValue((200, resp))
+
+
+class GroupRoomServlet(RestServlet):
+    """Get all rooms in a group
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+
+    def __init__(self, hs):
+        super(GroupRoomServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+
+        defer.returnValue((200, result))
+
+
+class GroupUsersServlet(RestServlet):
+    """Get all users in a group
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+
+    def __init__(self, hs):
+        super(GroupUsersServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+
+        defer.returnValue((200, result))
+
+
+class GroupInvitedUsersServlet(RestServlet):
+    """Get users invited to a group
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+
+    def __init__(self, hs):
+        super(GroupInvitedUsersServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        result = yield self.groups_handler.get_invited_users_in_group(
+            group_id,
+            requester_user_id,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupSettingJoinPolicyServlet(RestServlet):
+    """Set group join policy
+    """
+    PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+
+    def __init__(self, hs):
+        super(GroupSettingJoinPolicyServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+
+        result = yield self.groups_handler.set_group_join_policy(
+            group_id,
+            requester_user_id,
+            content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupCreateServlet(RestServlet):
+    """Create a group
+    """
+    PATTERNS = client_v2_patterns("/create_group$")
+
+    def __init__(self, hs):
+        super(GroupCreateServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+        self.server_name = hs.hostname
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        # TODO: Create group on remote server
+        content = parse_json_object_from_request(request)
+        localpart = content.pop("localpart")
+        group_id = GroupID(localpart, self.server_name).to_string()
+
+        result = yield self.groups_handler.create_group(
+            group_id,
+            requester_user_id,
+            content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupAdminRoomsServlet(RestServlet):
+    """Add a room to the group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupAdminRoomsServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.add_room_to_group(
+            group_id, requester_user_id, room_id, content,
+        )
+
+        defer.returnValue((200, result))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, group_id, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        result = yield self.groups_handler.remove_room_from_group(
+            group_id, requester_user_id, room_id,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+    """Update the config of a room in a group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupAdminRoomsConfigServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, room_id, config_key):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+    """Invite a user to the group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupAdminUsersInviteServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+        self.store = hs.get_datastore()
+        self.is_mine_id = hs.is_mine_id
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, user_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        config = content.get("config", {})
+        result = yield self.groups_handler.invite(
+            group_id, user_id, requester_user_id, config,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+    """Kick a user from the group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(GroupAdminUsersKickServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id, user_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.remove_user_from_group(
+            group_id, user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupSelfLeaveServlet(RestServlet):
+    """Leave a joined group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/self/leave$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSelfLeaveServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.remove_user_from_group(
+            group_id, requester_user_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupSelfJoinServlet(RestServlet):
+    """Attempt to join a group, or knock
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/self/join$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSelfJoinServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.join_group(
+            group_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupSelfAcceptInviteServlet(RestServlet):
+    """Accept a group invite
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSelfAcceptInviteServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        result = yield self.groups_handler.accept_invite(
+            group_id, requester_user_id, content,
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupSelfUpdatePublicityServlet(RestServlet):
+    """Update whether we publicise a users membership of a group
+    """
+    PATTERNS = client_v2_patterns(
+        "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
+    )
+
+    def __init__(self, hs):
+        super(GroupSelfUpdatePublicityServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, group_id):
+        requester = yield self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        publicise = content["publicise"]
+        yield self.store.update_group_publicity(
+            group_id, requester_user_id, publicise,
+        )
+
+        defer.returnValue((200, {}))
+
+
+class PublicisedGroupsForUserServlet(RestServlet):
+    """Get the list of groups a user is advertising
+    """
+    PATTERNS = client_v2_patterns(
+        "/publicised_groups/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs):
+        super(PublicisedGroupsForUserServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, user_id):
+        yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        result = yield self.groups_handler.get_publicised_groups_for_user(
+            user_id
+        )
+
+        defer.returnValue((200, result))
+
+
+class PublicisedGroupsForUsersServlet(RestServlet):
+    """Get the list of groups a user is advertising
+    """
+    PATTERNS = client_v2_patterns(
+        "/publicised_groups$"
+    )
+
+    def __init__(self, hs):
+        super(PublicisedGroupsForUsersServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        yield self.auth.get_user_by_req(request, allow_guest=True)
+
+        content = parse_json_object_from_request(request)
+        user_ids = content["user_ids"]
+
+        result = yield self.groups_handler.bulk_get_publicised_groups(
+            user_ids
+        )
+
+        defer.returnValue((200, result))
+
+
+class GroupsForUserServlet(RestServlet):
+    """Get all groups the logged in user is joined to
+    """
+    PATTERNS = client_v2_patterns(
+        "/joined_groups$"
+    )
+
+    def __init__(self, hs):
+        super(GroupsForUserServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        result = yield self.groups_handler.get_joined_groups(requester_user_id)
+
+        defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+    GroupServlet(hs).register(http_server)
+    GroupSummaryServlet(hs).register(http_server)
+    GroupInvitedUsersServlet(hs).register(http_server)
+    GroupUsersServlet(hs).register(http_server)
+    GroupRoomServlet(hs).register(http_server)
+    GroupSettingJoinPolicyServlet(hs).register(http_server)
+    GroupCreateServlet(hs).register(http_server)
+    GroupAdminRoomsServlet(hs).register(http_server)
+    GroupAdminRoomsConfigServlet(hs).register(http_server)
+    GroupAdminUsersInviteServlet(hs).register(http_server)
+    GroupAdminUsersKickServlet(hs).register(http_server)
+    GroupSelfLeaveServlet(hs).register(http_server)
+    GroupSelfJoinServlet(hs).register(http_server)
+    GroupSelfAcceptInviteServlet(hs).register(http_server)
+    GroupsForUserServlet(hs).register(http_server)
+    GroupCategoryServlet(hs).register(http_server)
+    GroupCategoriesServlet(hs).register(http_server)
+    GroupSummaryRoomsCatServlet(hs).register(http_server)
+    GroupRoleServlet(hs).register(http_server)
+    GroupRolesServlet(hs).register(http_server)
+    GroupSelfUpdatePublicityServlet(hs).register(http_server)
+    GroupSummaryUsersRoleServlet(hs).register(http_server)
+    PublicisedGroupsForUserServlet(hs).register(http_server)
+    PublicisedGroupsForUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 6a3cfe84f8..3cc87ea63f 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
       },
     }
     """
-    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
 
     def __init__(self, hs):
         """
@@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
     } } } } } }
     """
 
-    PATTERNS = client_v2_patterns(
-        "/keys/query$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/query$")
 
     def __init__(self, hs):
         """
@@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
         200 OK
         { "changed": ["@foo:example.com"] }
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/changes$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/changes$")
 
     def __init__(self, hs):
         """
@@ -188,13 +181,11 @@ class KeyChangesServlet(RestServlet):
 
         user_id = requester.user.to_string()
 
-        changed = yield self.device_handler.get_user_ids_changed(
+        results = yield self.device_handler.get_user_ids_changed(
             user_id, from_token,
         )
 
-        defer.returnValue((200, {
-            "changed": list(changed),
-        }))
+        defer.returnValue((200, results))
 
 
 class OneTimeKeyServlet(RestServlet):
@@ -215,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
     } } } }
 
     """
-    PATTERNS = client_v2_patterns(
-        "/keys/claim$",
-        releases=()
-    )
+    PATTERNS = client_v2_patterns("/keys/claim$")
 
     def __init__(self, hs):
         super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..ec170109fe 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 
 
 class NotificationsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/notifications$", releases=())
+    PATTERNS = client_v2_patterns("/notifications$")
 
     def __init__(self, hs):
         super(NotificationsServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
new file mode 100644
index 0000000000..2f8784fe06
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReadMarkerRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
+
+    def __init__(self, hs):
+        super(ReadMarkerRestServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.receipts_handler = hs.get_receipts_handler()
+        self.read_marker_handler = hs.get_read_marker_handler()
+        self.presence_handler = hs.get_presence_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+
+        yield self.presence_handler.bump_presence_active_time(requester.user)
+
+        body = parse_json_object_from_request(request)
+
+        read_event_id = body.get("m.read", None)
+        if read_event_id:
+            yield self.receipts_handler.received_client_receipt(
+                room_id,
+                "m.read",
+                user_id=requester.user.to_string(),
+                event_id=read_event_id
+            )
+
+        read_marker_event_id = body.get("m.fully_read", None)
+        if read_marker_event_id:
+            yield self.read_marker_handler.received_client_read_marker(
+                room_id,
+                user_id=requester.user.to_string(),
+                event_id=read_marker_event_id
+            )
+
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ccca5a12d5..5cab00aea9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,17 +17,25 @@
 from twisted.internet import defer
 
 import synapse
+import synapse.types
 from synapse.api.auth import get_access_token_from_request, has_access_token
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+    RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
+)
+from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
 
-from ._base import client_v2_patterns
+from ._base import client_v2_patterns, interactive_auth_handler
 
 import logging
 import hmac
 from hashlib import sha1
 from synapse.util.async import run_on_reactor
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from six import string_types
 
 
 # We ought to be using hmac.compare_digest() but on older pythons it doesn't
@@ -43,7 +52,7 @@ else:
 logger = logging.getLogger(__name__)
 
 
-class RegisterRequestTokenRestServlet(RestServlet):
+class EmailRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register/email/requestToken$")
 
     def __init__(self, hs):
@@ -51,7 +60,7 @@ class RegisterRequestTokenRestServlet(RestServlet):
         Args:
             hs (synapse.server.HomeServer): server
         """
-        super(RegisterRequestTokenRestServlet, self).__init__()
+        super(EmailRegisterRequestTokenRestServlet, self).__init__()
         self.hs = hs
         self.identity_handler = hs.get_handlers().identity_handler
 
@@ -59,14 +68,14 @@ class RegisterRequestTokenRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        required = ['id_server', 'client_secret', 'email', 'send_attempt']
-        absent = []
-        for k in required:
-            if k not in body:
-                absent.append(k)
+        assert_params_in_request(body, [
+            'id_server', 'client_secret', 'email', 'send_attempt'
+        ])
 
-        if len(absent) > 0:
-            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+        if not check_3pid_allowed(self.hs, "email", body['email']):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
 
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
             'email', body['email']
@@ -79,6 +88,86 @@ class RegisterRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))
 
 
+class MsisdnRegisterRequestTokenRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+        self.hs = hs
+        self.identity_handler = hs.get_handlers().identity_handler
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        assert_params_in_request(body, [
+            'id_server', 'client_secret',
+            'country', 'phone_number',
+            'send_attempt',
+        ])
+
+        msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+
+        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+            raise SynapseError(
+                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+            )
+
+        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+            'msisdn', msisdn
+        )
+
+        if existingUid is not None:
+            raise SynapseError(
+                400, "Phone number is already in use", Codes.THREEPID_IN_USE
+            )
+
+        ret = yield self.identity_handler.requestMsisdnToken(**body)
+        defer.returnValue((200, ret))
+
+
+class UsernameAvailabilityRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/register/available")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(UsernameAvailabilityRestServlet, self).__init__()
+        self.hs = hs
+        self.registration_handler = hs.get_handlers().registration_handler
+        self.ratelimiter = FederationRateLimiter(
+            hs.get_clock(),
+            # Time window of 2s
+            window_size=2000,
+            # Artificially delay requests if rate > sleep_limit/window_size
+            sleep_limit=1,
+            # Amount of artificial delay to apply
+            sleep_msec=1000,
+            # Error with 429 if more than reject_limit requests are queued
+            reject_limit=1,
+            # Allow 1 request at a time
+            concurrent_requests=1,
+        )
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        ip = self.hs.get_ip_from_request(request)
+        with self.ratelimiter.ratelimit(ip) as wait_deferred:
+            yield wait_deferred
+
+            username = parse_string(request, "username", required=True)
+
+            yield self.registration_handler.check_username(username)
+
+            defer.returnValue((200, {"available": True}))
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register$")
 
@@ -95,9 +184,11 @@ class RegisterRestServlet(RestServlet):
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_handlers().registration_handler
         self.identity_handler = hs.get_handlers().identity_handler
+        self.room_member_handler = hs.get_room_member_handler()
         self.device_handler = hs.get_device_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
 
+    @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         yield run_on_reactor()
@@ -121,14 +212,14 @@ class RegisterRestServlet(RestServlet):
         # in sessions. Pull out the username/password provided to us.
         desired_password = None
         if 'password' in body:
-            if (not isinstance(body['password'], basestring) or
+            if (not isinstance(body['password'], string_types) or
                     len(body['password']) > 512):
                 raise SynapseError(400, "Invalid password")
             desired_password = body["password"]
 
         desired_username = None
         if 'username' in body:
-            if (not isinstance(body['username'], basestring) or
+            if (not isinstance(body['username'], string_types) or
                     len(body['username']) > 512):
                 raise SynapseError(400, "Invalid username")
             desired_username = body['username']
@@ -146,15 +237,30 @@ class RegisterRestServlet(RestServlet):
             # 'user' key not 'username'). Since this is a new addition, we'll
             # fallback to 'username' if they gave one.
             desired_username = body.get("user", desired_username)
+
+            # XXX we should check that desired_username is valid. Currently
+            # we give appservices carte blanche for any insanity in mxids,
+            # because the IRC bridges rely on being able to register stupid
+            # IDs.
+
             access_token = get_access_token_from_request(request)
 
-            if isinstance(desired_username, basestring):
+            if isinstance(desired_username, string_types):
                 result = yield self._do_appservice_registration(
                     desired_username, access_token, body
                 )
             defer.returnValue((200, result))  # we throw for non 200 responses
             return
 
+        # for either shared secret or regular registration, downcase the
+        # provided username before attempting to register it. This should mean
+        # that people who try to register with upper-case in their usernames
+        # don't get a nasty surprise. (Note that we treat username
+        # case-insenstively in login, so they are free to carry on imagining
+        # that their username is CrAzYh4cKeR if that keeps them happy)
+        if desired_username is not None:
+            desired_username = desired_username.lower()
+
         # == Shared Secret Registration == (e.g. create new user scripts)
         if 'mac' in body:
             # FIXME: Should we really be determining if this is shared secret
@@ -200,32 +306,86 @@ class RegisterRestServlet(RestServlet):
                 assigned_user_id=registered_user_id,
             )
 
+        # Only give msisdn flows if the x_show_msisdn flag is given:
+        # this is a hack to work around the fact that clients were shipped
+        # that use fallback registration if they see any flows that they don't
+        # recognise, which means we break registration for these clients if we
+        # advertise msisdn flows. Once usage of Riot iOS <=0.3.9 and Riot
+        # Android <=0.6.9 have fallen below an acceptable threshold, this
+        # parameter should go away and we should always advertise msisdn flows.
+        show_msisdn = False
+        if 'x_show_msisdn' in body and body['x_show_msisdn']:
+            show_msisdn = True
+
+        # FIXME: need a better error than "no auth flow found" for scenarios
+        # where we required 3PID for registration but the user didn't give one
+        require_email = 'email' in self.hs.config.registrations_require_3pid
+        require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+        flows = []
         if self.hs.config.enable_registration_captcha:
-            flows = [
-                [LoginType.RECAPTCHA],
-                [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
-            ]
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.RECAPTCHA]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
+            if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email:
+                    flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+                # always let users provide both MSISDN & email
+                flows.extend([
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+                ])
         else:
-            flows = [
-                [LoginType.DUMMY],
-                [LoginType.EMAIL_IDENTITY]
-            ]
-
-        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+            # only support 3PIDless registration if no 3PIDs are required
+            if not require_email and not require_msisdn:
+                flows.extend([[LoginType.DUMMY]])
+            # only support the email-only flow if we don't require MSISDN 3PIDs
+            if not require_msisdn:
+                flows.extend([[LoginType.EMAIL_IDENTITY]])
+
+            if show_msisdn:
+                # only support the MSISDN-only flow if we don't require email 3PIDs
+                if not require_email or require_msisdn:
+                    flows.extend([[LoginType.MSISDN]])
+                # always let users provide both MSISDN & email
+                flows.extend([
+                    [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
+                ])
+
+        auth_result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
-        if not authed:
-            defer.returnValue((401, auth_result))
-            return
+        # Check that we're not trying to register a denied 3pid.
+        #
+        # the user-facing checks will probably already have happened in
+        # /register/email/requestToken when we requested a 3pid, but that's not
+        # guaranteed.
+
+        if auth_result:
+            for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+                if login_type in auth_result:
+                    medium = auth_result[login_type]['medium']
+                    address = auth_result[login_type]['address']
+
+                    if not check_3pid_allowed(self.hs, medium, address):
+                        raise SynapseError(
+                            403, "Third party identifier is not allowed",
+                            Codes.THREEPID_DENIED,
+                        )
 
         if registered_user_id is not None:
             logger.info(
                 "Already registered user ID %r for this session",
                 registered_user_id
             )
-            # don't re-register the email address
+            # don't re-register the threepids
             add_email = False
+            add_msisdn = False
         else:
             # NB: This may be from the auth handler and NOT from the POST
             if 'password' not in params:
@@ -236,6 +396,9 @@ class RegisterRestServlet(RestServlet):
             new_password = params.get("password", None)
             guest_access_token = params.get("guest_access_token", None)
 
+            if desired_username is not None:
+                desired_username = desired_username.lower()
+
             (registered_user_id, _) = yield self.registration_handler.register(
                 localpart=desired_username,
                 password=new_password,
@@ -250,6 +413,7 @@ class RegisterRestServlet(RestServlet):
             )
 
             add_email = True
+            add_msisdn = True
 
         return_dict = yield self._create_registration_details(
             registered_user_id, params
@@ -262,6 +426,13 @@ class RegisterRestServlet(RestServlet):
                 params.get("bind_email")
             )
 
+        if add_msisdn and auth_result and LoginType.MSISDN in auth_result:
+            threepid = auth_result[LoginType.MSISDN]
+            yield self._register_msisdn_threepid(
+                registered_user_id, threepid, return_dict["access_token"],
+                params.get("bind_msisdn")
+            )
+
         defer.returnValue((200, return_dict))
 
     def on_OPTIONS(self, _):
@@ -278,15 +449,24 @@ class RegisterRestServlet(RestServlet):
     def _do_shared_secret_registration(self, username, password, body):
         if not self.hs.config.registration_shared_secret:
             raise SynapseError(400, "Shared secret registration is not enabled")
+        if not username:
+            raise SynapseError(
+                400, "username must be specified", errcode=Codes.BAD_JSON,
+            )
 
-        user = username.encode("utf-8")
+        # use the username from the original request rather than the
+        # downcased one in `username` for the mac calculation
+        user = body["username"].encode("utf-8")
 
         # str() because otherwise hmac complains that 'unicode' does not
         # have the buffer interface
         got_mac = str(body["mac"])
 
+        # FIXME this is different to the /v1/register endpoint, which
+        # includes the password and admin flag in the hashed text. Why are
+        # these different?
         want_mac = hmac.new(
-            key=self.hs.config.registration_shared_secret,
+            key=self.hs.config.registration_shared_secret.encode(),
             msg=user,
             digestmod=sha1,
         ).hexdigest()
@@ -323,8 +503,9 @@ class RegisterRestServlet(RestServlet):
         """
         reqd = ('medium', 'address', 'validated_at')
         if any(x not in threepid for x in reqd):
+            # This will only happen if the ID server returns a malformed response
             logger.info("Can't add incomplete 3pid")
-            defer.returnValue()
+            return
 
         yield self.auth_handler.add_threepid(
             user_id,
@@ -372,6 +553,43 @@ class RegisterRestServlet(RestServlet):
             logger.info("bind_email not specified: not binding email")
 
     @defer.inlineCallbacks
+    def _register_msisdn_threepid(self, user_id, threepid, token, bind_msisdn):
+        """Add a phone number as a 3pid identifier
+
+        Also optionally binds msisdn to the given user_id on the identity server
+
+        Args:
+            user_id (str): id of user
+            threepid (object): m.login.msisdn auth response
+            token (str): access_token for the user
+            bind_email (bool): true if the client requested the email to be
+                bound at the identity server
+        Returns:
+            defer.Deferred:
+        """
+        reqd = ('medium', 'address', 'validated_at')
+        if any(x not in threepid for x in reqd):
+            # This will only happen if the ID server returns a malformed response
+            logger.info("Can't add incomplete 3pid")
+            defer.returnValue()
+
+        yield self.auth_handler.add_threepid(
+            user_id,
+            threepid['medium'],
+            threepid['address'],
+            threepid['validated_at'],
+        )
+
+        if bind_msisdn:
+            logger.info("bind_msisdn specified: binding")
+            logger.debug("Binding msisdn %s to %s", threepid, user_id)
+            yield self.identity_handler.bind_threepid(
+                threepid['threepid_creds'], user_id
+            )
+        else:
+            logger.info("bind_msisdn not specified: not binding msisdn")
+
+    @defer.inlineCallbacks
     def _create_registration_details(self, user_id, params):
         """Complete registration of newly-registered user
 
@@ -380,25 +598,28 @@ class RegisterRestServlet(RestServlet):
         Args:
             (str) user_id: full canonical @user:id
             (object) params: registration parameters, from which we pull
-                device_id and initial_device_name
+                device_id, initial_device_name and inhibit_login
         Returns:
             defer.Deferred: (object) dictionary for response from /register
         """
-        device_id = yield self._register_device(user_id, params)
+        result = {
+            "user_id": user_id,
+            "home_server": self.hs.hostname,
+        }
+        if not params.get("inhibit_login", False):
+            device_id = yield self._register_device(user_id, params)
 
-        access_token = (
-            yield self.auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id,
-                initial_display_name=params.get("initial_device_display_name")
+            access_token = (
+                yield self.auth_handler.get_access_token_for_user_id(
+                    user_id, device_id=device_id,
+                )
             )
-        )
 
-        defer.returnValue({
-            "user_id": user_id,
-            "access_token": access_token,
-            "home_server": self.hs.hostname,
-            "device_id": device_id,
-        })
+            result.update({
+                "access_token": access_token,
+                "device_id": device_id,
+            })
+        defer.returnValue(result)
 
     def _register_device(self, user_id, params):
         """Register a device for a user.
@@ -433,7 +654,7 @@ class RegisterRestServlet(RestServlet):
         # we have nowhere to store it.
         device_id = synapse.api.auth.GUEST_DEVICE_ID
         initial_display_name = params.get("initial_device_display_name")
-        self.device_handler.check_device_registered(
+        yield self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
 
@@ -449,5 +670,7 @@ class RegisterRestServlet(RestServlet):
 
 
 def register_servlets(hs, http_server):
-    RegisterRequestTokenRestServlet(hs).register(http_server)
+    EmailRegisterRequestTokenRestServlet(hs).register(http_server)
+    MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
+    UsernameAvailabilityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..90bdb1db15 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class SendToDeviceRestServlet(servlet.RestServlet):
     PATTERNS = client_v2_patterns(
         "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
-        releases=[], v2_alpha=False
+        v2_alpha=False
     )
 
     def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index b3d8001638..eb91c0b293 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.http.servlet import (
     RestServlet, parse_string, parse_integer, parse_boolean
 )
+from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.sync import SyncConfig
 from synapse.types import StreamToken
 from synapse.events.utils import (
@@ -27,12 +28,12 @@ from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
 from synapse.api.errors import SynapseError
 from synapse.api.constants import PresenceState
 from ._base import client_v2_patterns
+from ._base import set_timeline_upper_limit
 
-import copy
 import itertools
 import logging
 
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -78,6 +79,7 @@ class SyncRestServlet(RestServlet):
 
     def __init__(self, hs):
         super(SyncRestServlet, self).__init__()
+        self.hs = hs
         self.auth = hs.get_auth()
         self.sync_handler = hs.get_sync_handler()
         self.clock = hs.get_clock()
@@ -108,7 +110,7 @@ class SyncRestServlet(RestServlet):
         filter_id = parse_string(request, "filter", default=None)
         full_state = parse_boolean(request, "full_state", default=False)
 
-        logger.info(
+        logger.debug(
             "/sync: user=%r, timeout=%r, since=%r,"
             " set_presence=%r, filter_id=%r, device_id=%r" % (
                 user, timeout, since, set_presence, filter_id, device_id
@@ -121,7 +123,9 @@ class SyncRestServlet(RestServlet):
             if filter_id.startswith('{'):
                 try:
                     filter_object = json.loads(filter_id)
-                except:
+                    set_timeline_upper_limit(filter_object,
+                                             self.hs.config.filter_timeline_limit)
+                except Exception:
                     raise SynapseError(400, "Invalid filter JSON")
                 self.filtering.check_valid_filter(filter_object)
                 filter = FilterCollection(filter_object)
@@ -160,27 +164,35 @@ class SyncRestServlet(RestServlet):
             )
 
         time_now = self.clock.time_msec()
+        response_content = self.encode_response(
+            time_now, sync_result, requester.access_token_id, filter
+        )
+
+        defer.returnValue((200, response_content))
 
-        joined = self.encode_joined(
-            sync_result.joined, time_now, requester.access_token_id, filter.event_fields
+    @staticmethod
+    def encode_response(time_now, sync_result, access_token_id, filter):
+        joined = SyncRestServlet.encode_joined(
+            sync_result.joined, time_now, access_token_id, filter.event_fields
         )
 
-        invited = self.encode_invited(
-            sync_result.invited, time_now, requester.access_token_id
+        invited = SyncRestServlet.encode_invited(
+            sync_result.invited, time_now, access_token_id,
         )
 
-        archived = self.encode_archived(
-            sync_result.archived, time_now, requester.access_token_id,
+        archived = SyncRestServlet.encode_archived(
+            sync_result.archived, time_now, access_token_id,
             filter.event_fields,
         )
 
-        response_content = {
+        return {
             "account_data": {"events": sync_result.account_data},
             "to_device": {"events": sync_result.to_device},
             "device_lists": {
-                "changed": list(sync_result.device_lists),
+                "changed": list(sync_result.device_lists.changed),
+                "left": list(sync_result.device_lists.left),
             },
-            "presence": self.encode_presence(
+            "presence": SyncRestServlet.encode_presence(
                 sync_result.presence, time_now
             ),
             "rooms": {
@@ -188,20 +200,32 @@ class SyncRestServlet(RestServlet):
                 "invite": invited,
                 "leave": archived,
             },
+            "groups": {
+                "join": sync_result.groups.join,
+                "invite": sync_result.groups.invite,
+                "leave": sync_result.groups.leave,
+            },
+            "device_one_time_keys_count": sync_result.device_one_time_keys_count,
             "next_batch": sync_result.next_batch.to_string(),
         }
 
-        defer.returnValue((200, response_content))
-
-    def encode_presence(self, events, time_now):
-        formatted = []
-        for event in events:
-            event = copy.deepcopy(event)
-            event['sender'] = event['content'].pop('user_id')
-            formatted.append(event)
-        return {"events": formatted}
+    @staticmethod
+    def encode_presence(events, time_now):
+        return {
+            "events": [
+                {
+                    "type": "m.presence",
+                    "sender": event.user_id,
+                    "content": format_user_presence_state(
+                        event, time_now, include_user_id=False
+                    ),
+                }
+                for event in events
+            ]
+        }
 
-    def encode_joined(self, rooms, time_now, token_id, event_fields):
+    @staticmethod
+    def encode_joined(rooms, time_now, token_id, event_fields):
         """
         Encode the joined rooms in a sync result
 
@@ -220,13 +244,14 @@ class SyncRestServlet(RestServlet):
         """
         joined = {}
         for room in rooms:
-            joined[room.room_id] = self.encode_room(
+            joined[room.room_id] = SyncRestServlet.encode_room(
                 room, time_now, token_id, only_fields=event_fields
             )
 
         return joined
 
-    def encode_invited(self, rooms, time_now, token_id):
+    @staticmethod
+    def encode_invited(rooms, time_now, token_id):
         """
         Encode the invited rooms in a sync result
 
@@ -247,6 +272,7 @@ class SyncRestServlet(RestServlet):
             invite = serialize_event(
                 room.invite, time_now, token_id=token_id,
                 event_format=format_event_for_client_v2_without_room_id,
+                is_invite=True,
             )
             unsigned = dict(invite.get("unsigned", {}))
             invite["unsigned"] = unsigned
@@ -258,7 +284,8 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    def encode_archived(self, rooms, time_now, token_id, event_fields):
+    @staticmethod
+    def encode_archived(rooms, time_now, token_id, event_fields):
         """
         Encode the archived rooms in a sync result
 
@@ -277,7 +304,7 @@ class SyncRestServlet(RestServlet):
         """
         joined = {}
         for room in rooms:
-            joined[room.room_id] = self.encode_room(
+            joined[room.room_id] = SyncRestServlet.encode_room(
                 room, time_now, token_id, joined=False, only_fields=event_fields
             )
 
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 31f94bc6e9..6773b9ba60 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 
 class ThirdPartyProtocolsServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocols")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolsServlet, self).__init__()
@@ -36,15 +36,14 @@ class ThirdPartyProtocolsServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols()
         defer.returnValue((200, protocols))
 
 
 class ThirdPartyProtocolServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
 
     def __init__(self, hs):
         super(ThirdPartyProtocolServlet, self).__init__()
@@ -54,7 +53,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         protocols = yield self.appservice_handler.get_3pe_protocols(
             only_protocol=protocol,
@@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
 
 
 class ThirdPartyUserServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyUserServlet, self).__init__()
@@ -77,7 +75,7 @@ class ThirdPartyUserServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
@@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
 
 
 class ThirdPartyLocationServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
-                                  releases=())
+    PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
 
     def __init__(self, hs):
         super(ThirdPartyLocationServlet, self).__init__()
@@ -101,7 +98,7 @@ class ThirdPartyLocationServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, request, protocol):
-        yield self.auth.get_user_by_req(request)
+        yield self.auth.get_user_by_req(request, allow_guest=True)
 
         fields = request.args
         fields.pop("access_token", None)
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
new file mode 100644
index 0000000000..2d4a43c353
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectorySearchRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/user_directory/search$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(UserDirectorySearchRestServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.user_directory_handler = hs.get_user_directory_handler()
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        """Searches for users in directory
+
+        Returns:
+            dict of the form::
+
+                {
+                    "limited": <bool>,  # whether there were more results or not
+                    "results": [  # Ordered by best match first
+                        {
+                            "user_id": <user_id>,
+                            "display_name": <display_name>,
+                            "avatar_url": <avatar_url>
+                        }
+                    ]
+                }
+        """
+        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
+        body = parse_json_object_from_request(request)
+
+        limit = body.get("limit", 10)
+        limit = min(limit, 50)
+
+        try:
+            search_term = body["search_term"]
+        except Exception:
+            raise SynapseError(400, "`search_term` is required field")
+
+        results = yield self.user_directory_handler.search_users(
+            user_id, search_term, limit,
+        )
+
+        defer.returnValue((200, results))
+
+
+def register_servlets(hs, http_server):
+    UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index e984ea47db..2ecb15deee 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
                 "r0.0.1",
                 "r0.1.0",
                 "r0.2.0",
+                "r0.3.0",
             ]
         })
 
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index ff95269ba8..be68d9a096 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -84,12 +84,11 @@ class LocalKey(Resource):
             }
 
         old_verify_keys = {}
-        for key in self.config.old_signing_keys:
-            key_id = "%s:%s" % (key.alg, key.version)
+        for key_id, key in self.config.old_signing_keys.items():
             verify_key_bytes = key.encode()
             old_verify_keys[key_id] = {
                 u"key": encode_base64(verify_key_bytes),
-                u"expired_ts": key.expired,
+                u"expired_ts": key.expired_ts,
             }
 
         tls_fingerprints = self.config.tls_fingerprints
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9fe2013657..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
         self.store = hs.get_datastore()
         self.version_string = hs.version_string
         self.clock = hs.get_clock()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     def render_GET(self, request):
         self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
         logger.info("Handling query for keys %r", query)
         store_queries = []
         for server_name, key_ids in query.items():
+            if (
+                self.federation_domain_whitelist is not None and
+                server_name not in self.federation_domain_whitelist
+            ):
+                logger.debug("Federation denied with %s", server_name)
+                continue
+
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
@@ -213,7 +221,7 @@ class RemoteKey(Resource):
                     )
                 except KeyLookupError as e:
                     logger.info("Failed to fetch key: %s", e)
-                except:
+                except Exception:
                     logger.exception("Failed to get key for %r", server_name)
             yield self.query_keys(
                 request, query, query_remote_on_cache_miss=False
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index b9600f2167..c0d2f06855 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,6 +17,7 @@ from synapse.http.server import respond_with_json, finish_request
 from synapse.api.errors import (
     cs_error, Codes, SynapseError
 )
+from synapse.util import logcontext
 
 from twisted.internet import defer
 from twisted.protocols.basic import FileSender
@@ -27,7 +28,7 @@ import os
 
 import logging
 import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +45,7 @@ def parse_media_id(request):
             except UnicodeDecodeError:
                 pass
         return server_name, media_id, file_name
-    except:
+    except Exception:
         raise SynapseError(
             404,
             "Invalid media id token %r" % (request.postpath,),
@@ -69,42 +70,133 @@ def respond_with_file(request, media_type, file_path,
     logger.debug("Responding with %r", file_path)
 
     if os.path.isfile(file_path):
-        request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
-        if upload_name:
-            if is_ascii(upload_name):
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename=%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-            else:
-                request.setHeader(
-                    b"Content-Disposition",
-                    b"inline; filename*=utf-8''%s" % (
-                        urllib.quote(upload_name.encode("utf-8")),
-                    ),
-                )
-
-        # cache for at least a day.
-        # XXX: we might want to turn this off for data we don't want to
-        # recommend caching as it's sensitive or private - or at least
-        # select private. don't bother setting Expires as all our
-        # clients are smart enough to be happy with Cache-Control
-        request.setHeader(
-            b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
-        )
         if file_size is None:
             stat = os.stat(file_path)
             file_size = stat.st_size
 
-        request.setHeader(
-            b"Content-Length", b"%d" % (file_size,)
-        )
+        add_file_headers(request, media_type, file_size, upload_name)
 
         with open(file_path, "rb") as f:
-            yield FileSender().beginFileTransfer(f, request)
+            yield logcontext.make_deferred_yieldable(
+                FileSender().beginFileTransfer(f, request)
+            )
 
         finish_request(request)
     else:
         respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+    """Adds the correct response headers in preparation for responding with the
+    media.
+
+    Args:
+        request (twisted.web.http.Request)
+        media_type (str): The media/content type.
+        file_size (int): Size in bytes of the media, if known.
+        upload_name (str): The name of the requested file, if any.
+    """
+    request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+    if upload_name:
+        if is_ascii(upload_name):
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename=%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+        else:
+            request.setHeader(
+                b"Content-Disposition",
+                b"inline; filename*=utf-8''%s" % (
+                    urllib.quote(upload_name.encode("utf-8")),
+                ),
+            )
+
+    # cache for at least a day.
+    # XXX: we might want to turn this off for data we don't want to
+    # recommend caching as it's sensitive or private - or at least
+    # select private. don't bother setting Expires as all our
+    # clients are smart enough to be happy with Cache-Control
+    request.setHeader(
+        b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+    )
+
+    request.setHeader(
+        b"Content-Length", b"%d" % (file_size,)
+    )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+    """Responds to the request with given responder. If responder is None then
+    returns 404.
+
+    Args:
+        request (twisted.web.http.Request)
+        responder (Responder|None)
+        media_type (str): The media/content type.
+        file_size (int|None): Size in bytes of the media. If not known it should be None
+        upload_name (str|None): The name of the requested file, if any.
+    """
+    if not responder:
+        respond_404(request)
+        return
+
+    logger.debug("Responding to media request with responder %s")
+    add_file_headers(request, media_type, file_size, upload_name)
+    with responder:
+        yield responder.write_to_consumer(request)
+    finish_request(request)
+
+
+class Responder(object):
+    """Represents a response that can be streamed to the requester.
+
+    Responder is a context manager which *must* be used, so that any resources
+    held can be cleaned up.
+    """
+    def write_to_consumer(self, consumer):
+        """Stream response into consumer
+
+        Args:
+            consumer (IConsumer)
+
+        Returns:
+            Deferred: Resolves once the response has finished being written
+        """
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+
+class FileInfo(object):
+    """Details about a requested/uploaded file.
+
+    Attributes:
+        server_name (str): The server name where the media originated from,
+            or None if local.
+        file_id (str): The local ID of the file. For local files this is the
+            same as the media_id
+        url_cache (bool): If the file is for the url preview cache
+        thumbnail (bool): Whether the file is a thumbnail or not.
+        thumbnail_width (int)
+        thumbnail_height (int)
+        thumbnail_method (str)
+        thumbnail_type (str): Content type of thumbnail, e.g. image/png
+    """
+    def __init__(self, server_name, file_id, url_cache=False,
+                 thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+                 thumbnail_method=None, thumbnail_type=None):
+        self.server_name = server_name
+        self.file_id = file_id
+        self.url_cache = url_cache
+        self.thumbnail = thumbnail
+        self.thumbnail_width = thumbnail_width
+        self.thumbnail_height = thumbnail_height
+        self.thumbnail_method = thumbnail_method
+        self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index dfb87ffd15..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,8 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.http.servlet
 
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
 from twisted.web.resource import Resource
 from synapse.http.server import request_handler, set_cors_headers
 
@@ -31,12 +32,12 @@ class DownloadResource(Resource):
     def __init__(self, hs, media_repo):
         Resource.__init__(self)
 
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
         self.server_name = hs.hostname
-        self.store = hs.get_datastore()
-        self.version_string = hs.version_string
+
+        # Both of these are expected by @request_handler()
         self.clock = hs.get_clock()
+        self.version_string = hs.version_string
 
     def render_GET(self, request):
         self._async_render_GET(request)
@@ -56,43 +57,16 @@ class DownloadResource(Resource):
         )
         server_name, media_id, name = parse_media_id(request)
         if server_name == self.server_name:
-            yield self._respond_local_file(request, media_id, name)
+            yield self.media_repo.get_local_media(request, media_id, name)
         else:
-            yield self._respond_remote_file(
-                request, server_name, media_id, name
-            )
-
-    @defer.inlineCallbacks
-    def _respond_local_file(self, request, media_id, name):
-        media_info = yield self.store.get_local_media(media_id)
-        if not media_info:
-            respond_404(request)
-            return
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        upload_name = name if name else media_info["upload_name"]
-        file_path = self.filepaths.local_media_filepath(media_id)
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
-
-    @defer.inlineCallbacks
-    def _respond_remote_file(self, request, server_name, media_id, name):
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        media_type = media_info["media_type"]
-        media_length = media_info["media_length"]
-        filesystem_id = media_info["filesystem_id"]
-        upload_name = name if name else media_info["upload_name"]
-
-        file_path = self.filepaths.remote_media_filepath(
-            server_name, filesystem_id
-        )
-
-        yield respond_with_file(
-            request, media_type, file_path, media_length,
-            upload_name=upload_name,
-        )
+            allow_remote = synapse.http.servlet.parse_boolean(
+                request, "allow_remote", default=True)
+            if not allow_remote:
+                logger.info(
+                    "Rejecting request for remote media %s/%s due to allow_remote",
+                    server_name, media_id,
+                )
+                respond_404(request)
+                return
+
+            yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 0137458f71..d5164e47e0 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -14,60 +14,200 @@
 # limitations under the License.
 
 import os
+import re
+import functools
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
+
+
+def _wrap_in_base_path(func):
+    """Takes a function that returns a relative path and turns it into an
+    absolute path based on the location of the primary media store
+    """
+    @functools.wraps(func)
+    def _wrapped(self, *args, **kwargs):
+        path = func(self, *args, **kwargs)
+        return os.path.join(self.base_path, path)
+
+    return _wrapped
 
 
 class MediaFilePaths(object):
+    """Describes where files are stored on disk.
 
-    def __init__(self, base_path):
-        self.base_path = base_path
+    Most of the functions have a `*_rel` variant which returns a file path that
+    is relative to the base media store path. This is mainly used when we want
+    to write to the backup media store (when one is configured)
+    """
 
-    def default_thumbnail(self, default_top_level, default_sub_type, width,
-                          height, content_type, method):
+    def __init__(self, primary_base_path):
+        self.base_path = primary_base_path
+
+    def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
+                              height, content_type, method):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (
             width, height, top_level_type, sub_type, method
         )
         return os.path.join(
-            self.base_path, "default_thumbnails", default_top_level,
+            "default_thumbnails", default_top_level,
             default_sub_type, file_name
         )
 
-    def local_media_filepath(self, media_id):
+    default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
+
+    def local_media_filepath_rel(self, media_id):
         return os.path.join(
-            self.base_path, "local_content",
+            "local_content",
             media_id[0:2], media_id[2:4], media_id[4:]
         )
 
-    def local_media_thumbnail(self, media_id, width, height, content_type,
-                              method):
+    local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
+
+    def local_media_thumbnail_rel(self, media_id, width, height, content_type,
+                                  method):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s-%s" % (
             width, height, top_level_type, sub_type, method
         )
         return os.path.join(
-            self.base_path, "local_thumbnails",
+            "local_thumbnails",
             media_id[0:2], media_id[2:4], media_id[4:],
             file_name
         )
 
-    def remote_media_filepath(self, server_name, file_id):
+    local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+
+    def remote_media_filepath_rel(self, server_name, file_id):
         return os.path.join(
-            self.base_path, "remote_content", server_name,
+            "remote_content", server_name,
             file_id[0:2], file_id[2:4], file_id[4:]
         )
 
-    def remote_media_thumbnail(self, server_name, file_id, width, height,
-                               content_type, method):
+    remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
+
+    def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
+                                   content_type, method):
         top_level_type, sub_type = content_type.split("/")
         file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
         return os.path.join(
-            self.base_path, "remote_thumbnail", server_name,
+            "remote_thumbnail", server_name,
             file_id[0:2], file_id[2:4], file_id[4:],
             file_name
         )
 
+    remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+
     def remote_media_thumbnail_dir(self, server_name, file_id):
         return os.path.join(
             self.base_path, "remote_thumbnail", server_name,
             file_id[0:2], file_id[2:4], file_id[4:],
         )
+
+    def url_cache_filepath_rel(self, media_id):
+        if NEW_FORMAT_ID_RE.match(media_id):
+            # Media id is of the form <DATE><RANDOM_STRING>
+            # E.g.: 2017-09-28-fsdRDt24DS234dsf
+            return os.path.join(
+                "url_cache",
+                media_id[:10], media_id[11:]
+            )
+        else:
+            return os.path.join(
+                "url_cache",
+                media_id[0:2], media_id[2:4], media_id[4:],
+            )
+
+    url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
+
+    def url_cache_filepath_dirs_to_delete(self, media_id):
+        "The dirs to try and remove if we delete the media_id file"
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return [
+                os.path.join(
+                    self.base_path, "url_cache",
+                    media_id[:10],
+                ),
+            ]
+        else:
+            return [
+                os.path.join(
+                    self.base_path, "url_cache",
+                    media_id[0:2], media_id[2:4],
+                ),
+                os.path.join(
+                    self.base_path, "url_cache",
+                    media_id[0:2],
+                ),
+            ]
+
+    def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
+                                method):
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+        top_level_type, sub_type = content_type.split("/")
+        file_name = "%i-%i-%s-%s-%s" % (
+            width, height, top_level_type, sub_type, method
+        )
+
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return os.path.join(
+                "url_cache_thumbnails",
+                media_id[:10], media_id[11:],
+                file_name
+            )
+        else:
+            return os.path.join(
+                "url_cache_thumbnails",
+                media_id[0:2], media_id[2:4], media_id[4:],
+                file_name
+            )
+
+    url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
+
+    def url_cache_thumbnail_directory(self, media_id):
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return os.path.join(
+                self.base_path, "url_cache_thumbnails",
+                media_id[:10], media_id[11:],
+            )
+        else:
+            return os.path.join(
+                self.base_path, "url_cache_thumbnails",
+                media_id[0:2], media_id[2:4], media_id[4:],
+            )
+
+    def url_cache_thumbnail_dirs_to_delete(self, media_id):
+        "The dirs to try and remove if we delete the media_id thumbnails"
+        # Media id is of the form <DATE><RANDOM_STRING>
+        # E.g.: 2017-09-28-fsdRDt24DS234dsf
+        if NEW_FORMAT_ID_RE.match(media_id):
+            return [
+                os.path.join(
+                    self.base_path, "url_cache_thumbnails",
+                    media_id[:10], media_id[11:],
+                ),
+                os.path.join(
+                    self.base_path, "url_cache_thumbnails",
+                    media_id[:10],
+                ),
+            ]
+        else:
+            return [
+                os.path.join(
+                    self.base_path, "url_cache_thumbnails",
+                    media_id[0:2], media_id[2:4], media_id[4:],
+                ),
+                os.path.join(
+                    self.base_path, "url_cache_thumbnails",
+                    media_id[0:2], media_id[2:4],
+                ),
+                os.path.join(
+                    self.base_path, "url_cache_thumbnails",
+                    media_id[0:2],
+                ),
+            ]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 3cbeca503c..9800ce7581 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,26 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer, threads
+import twisted.internet.error
+import twisted.web.http
+from twisted.web.resource import Resource
+
+from ._base import respond_404, FileInfo, respond_with_responder
 from .upload_resource import UploadResource
 from .download_resource import DownloadResource
 from .thumbnail_resource import ThumbnailResource
 from .identicon_resource import IdenticonResource
 from .preview_url_resource import PreviewUrlResource
 from .filepath import MediaFilePaths
-
-from twisted.web.resource import Resource
-
 from .thumbnailer import Thumbnailer
+from .storage_provider import StorageProviderWrapper
+from .media_storage import MediaStorage
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError
-
-from twisted.internet import defer, threads
+from synapse.api.errors import (
+    SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
 
 from synapse.util.async import Linearizer
 from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import make_deferred_yieldable
+from synapse.util.retryutils import NotRetryingDestination
 
 import os
 import errno
@@ -40,12 +47,12 @@ import shutil
 
 import cgi
 import logging
-import urlparse
+from six.moves.urllib import parse as urlparse
 
 logger = logging.getLogger(__name__)
 
 
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
 
 
 class MediaRepository(object):
@@ -57,46 +64,90 @@ class MediaRepository(object):
         self.store = hs.get_datastore()
         self.max_upload_size = hs.config.max_upload_size
         self.max_image_pixels = hs.config.max_image_pixels
-        self.filepaths = MediaFilePaths(hs.config.media_store_path)
+
+        self.primary_base_path = hs.config.media_store_path
+        self.filepaths = MediaFilePaths(self.primary_base_path)
+
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.thumbnail_requirements = hs.config.thumbnail_requirements
 
         self.remote_media_linearizer = Linearizer(name="media_remote")
 
         self.recently_accessed_remotes = set()
+        self.recently_accessed_locals = set()
+
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+
+        # List of StorageProviders where we should search for media and
+        # potentially upload to.
+        storage_providers = []
+
+        for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+            backend = clz(hs, provider_config)
+            provider = StorageProviderWrapper(
+                backend,
+                store_local=wrapper_config.store_local,
+                store_remote=wrapper_config.store_remote,
+                store_synchronous=wrapper_config.store_synchronous,
+            )
+            storage_providers.append(provider)
+
+        self.media_storage = MediaStorage(
+            self.primary_base_path, self.filepaths, storage_providers,
+        )
 
         self.clock.looping_call(
-            self._update_recently_accessed_remotes,
-            UPDATE_RECENTLY_ACCESSED_REMOTES_TS
+            self._update_recently_accessed,
+            UPDATE_RECENTLY_ACCESSED_TS,
         )
 
     @defer.inlineCallbacks
-    def _update_recently_accessed_remotes(self):
-        media = self.recently_accessed_remotes
+    def _update_recently_accessed(self):
+        remote_media = self.recently_accessed_remotes
         self.recently_accessed_remotes = set()
 
+        local_media = self.recently_accessed_locals
+        self.recently_accessed_locals = set()
+
         yield self.store.update_cached_last_access_time(
-            media, self.clock.time_msec()
+            local_media, remote_media, self.clock.time_msec()
         )
 
-    @staticmethod
-    def _makedirs(filepath):
-        dirname = os.path.dirname(filepath)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
+    def mark_recently_accessed(self, server_name, media_id):
+        """Mark the given media as recently accessed.
+
+        Args:
+            server_name (str|None): Origin server of media, or None if local
+            media_id (str): The media ID of the content
+        """
+        if server_name:
+            self.recently_accessed_remotes.add((server_name, media_id))
+        else:
+            self.recently_accessed_locals.add(media_id)
 
     @defer.inlineCallbacks
     def create_content(self, media_type, upload_name, content, content_length,
                        auth_user):
+        """Store uploaded content for a local user and return the mxc URL
+
+        Args:
+            media_type(str): The content type of the file
+            upload_name(str): The name of the file
+            content: A file like object that is the content to store
+            content_length(int): The length of the content
+            auth_user(str): The user_id of the uploader
+
+        Returns:
+            Deferred[str]: The mxc url of the stored content
+        """
         media_id = random_string(24)
 
-        fname = self.filepaths.local_media_filepath(media_id)
-        self._makedirs(fname)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=media_id,
+        )
 
-        # This shouldn't block for very long because the content will have
-        # already been uploaded at this point.
-        with open(fname, "wb") as f:
-            f.write(content)
+        fname = yield self.media_storage.store_file(content, file_info)
 
         logger.info("Stored local media in file %r", fname)
 
@@ -108,104 +159,275 @@ class MediaRepository(object):
             media_length=content_length,
             user_id=auth_user,
         )
-        media_info = {
-            "media_type": media_type,
-            "media_length": content_length,
-        }
 
-        yield self._generate_local_thumbnails(media_id, media_info)
+        yield self._generate_thumbnails(
+            None, media_id, media_id, media_type,
+        )
 
         defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
 
     @defer.inlineCallbacks
-    def get_remote_media(self, server_name, media_id):
+    def get_local_media(self, request, media_id, name):
+        """Responds to reqests for local media, if exists, or returns 404.
+
+        Args:
+            request(twisted.web.http.Request)
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content.)
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        media_info = yield self.store.get_local_media(media_id)
+        if not media_info or media_info["quarantined_by"]:
+            respond_404(request)
+            return
+
+        self.mark_recently_accessed(None, media_id)
+
+        media_type = media_info["media_type"]
+        media_length = media_info["media_length"]
+        upload_name = name if name else media_info["upload_name"]
+        url_cache = media_info["url_cache"]
+
+        file_info = FileInfo(
+            None, media_id,
+            url_cache=url_cache,
+        )
+
+        responder = yield self.media_storage.fetch_media(file_info)
+        yield respond_with_responder(
+            request, responder, media_type, media_length, upload_name,
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_media(self, request, server_name, media_id, name):
+        """Respond to requests for remote media.
+
+        Args:
+            request(twisted.web.http.Request)
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+            name (str|None): Optional name that, if specified, will be used as
+                the filename in the Content-Disposition header of the response.
+
+        Returns:
+            Deferred: Resolves once a response has successfully been written
+                to request
+        """
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        self.mark_recently_accessed(server_name, media_id)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
+        key = (server_name, media_id)
+        with (yield self.remote_media_linearizer.queue(key)):
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # We deliberately stream the file outside the lock
+        if responder:
+            media_type = media_info["media_type"]
+            media_length = media_info["media_length"]
+            upload_name = name if name else media_info["upload_name"]
+            yield respond_with_responder(
+                request, responder, media_type, media_length, upload_name,
+            )
+        else:
+            respond_404(request)
+
+    @defer.inlineCallbacks
+    def get_remote_media_info(self, server_name, media_id):
+        """Gets the media info associated with the remote file, downloading
+        if necessary.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[dict]: The media_info of the file
+        """
+        if (
+            self.federation_domain_whitelist is not None and
+            server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(server_name)
+
+        # We linearize here to ensure that we don't try and download remote
+        # media multiple times concurrently
         key = (server_name, media_id)
         with (yield self.remote_media_linearizer.queue(key)):
-            media_info = yield self._get_remote_media_impl(server_name, media_id)
+            responder, media_info = yield self._get_remote_media_impl(
+                server_name, media_id,
+            )
+
+        # Ensure we actually use the responder so that it releases resources
+        if responder:
+            with responder:
+                pass
+
         defer.returnValue(media_info)
 
     @defer.inlineCallbacks
     def _get_remote_media_impl(self, server_name, media_id):
+        """Looks for media in local cache, if not there then attempt to
+        download from remote server.
+
+        Args:
+            server_name (str): Remote server_name where the media originated.
+            media_id (str): The media ID of the content (as defined by the
+                remote server).
+
+        Returns:
+            Deferred[(Responder, media_info)]
+        """
         media_info = yield self.store.get_cached_remote_media(
             server_name, media_id
         )
-        if not media_info:
-            media_info = yield self._download_remote_file(
-                server_name, media_id
-            )
+
+        # file_id is the ID we use to track the file locally. If we've already
+        # seen the file then reuse the existing ID, otherwise genereate a new
+        # one.
+        if media_info:
+            file_id = media_info["filesystem_id"]
         else:
-            self.recently_accessed_remotes.add((server_name, media_id))
-            yield self.store.update_cached_last_access_time(
-                [(server_name, media_id)], self.clock.time_msec()
-            )
-        defer.returnValue(media_info)
+            file_id = random_string(24)
 
-    @defer.inlineCallbacks
-    def _download_remote_file(self, server_name, media_id):
-        file_id = random_string(24)
+        file_info = FileInfo(server_name, file_id)
+
+        # If we have an entry in the DB, try and look for it
+        if media_info:
+            if media_info["quarantined_by"]:
+                logger.info("Media is quarantined")
+                raise NotFoundError()
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            if responder:
+                defer.returnValue((responder, media_info))
+
+        # Failed to find the file anywhere, lets download it.
 
-        fname = self.filepaths.remote_media_filepath(
-            server_name, file_id
+        media_info = yield self._download_remote_file(
+            server_name, media_id, file_id
         )
-        self._makedirs(fname)
 
-        try:
-            with open(fname, "wb") as f:
-                request_path = "/".join((
-                    "/_matrix/media/v1/download", server_name, media_id,
-                ))
+        responder = yield self.media_storage.fetch_media(file_info)
+        defer.returnValue((responder, media_info))
+
+    @defer.inlineCallbacks
+    def _download_remote_file(self, server_name, media_id, file_id):
+        """Attempt to download the remote file from the given server name,
+        using the given file_id as the local id.
+
+        Args:
+            server_name (str): Originating server
+            media_id (str): The media ID of the content (as defined by the
+                remote server). This is different than the file_id, which is
+                locally generated.
+            file_id (str): Local file ID
+
+        Returns:
+            Deferred[MediaInfo]
+        """
+
+        file_info = FileInfo(
+            server_name=server_name,
+            file_id=file_id,
+        )
+
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            request_path = "/".join((
+                "/_matrix/media/v1/download", server_name, media_id,
+            ))
+            try:
+                length, headers = yield self.client.get_file(
+                    server_name, request_path, output_stream=f,
+                    max_size=self.max_upload_size, args={
+                        # tell the remote server to 404 if it doesn't
+                        # recognise the server_name, to make sure we don't
+                        # end up with a routing loop.
+                        "allow_remote": "false",
+                    }
+                )
+            except twisted.internet.error.DNSLookupError as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %r",
+                            server_name, media_id, e)
+                raise NotFoundError()
+
+            except HttpResponseException as e:
+                logger.warn("HTTP error fetching remote media %s/%s: %s",
+                            server_name, media_id, e.response)
+                if e.code == twisted.web.http.NOT_FOUND:
+                    raise SynapseError.from_http_response_exception(e)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            except SynapseError:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise
+            except NotRetryingDestination:
+                logger.warn("Not retrying destination %r", server_name)
+                raise SynapseError(502, "Failed to fetch remote media")
+            except Exception:
+                logger.exception("Failed to fetch remote media %s/%s",
+                                 server_name, media_id)
+                raise SynapseError(502, "Failed to fetch remote media")
+
+            yield finish()
+
+        media_type = headers["Content-Type"][0]
+
+        time_now_ms = self.clock.time_msec()
+
+        content_disposition = headers.get("Content-Disposition", None)
+        if content_disposition:
+            _, params = cgi.parse_header(content_disposition[0],)
+            upload_name = None
+
+            # First check if there is a valid UTF-8 filename
+            upload_name_utf8 = params.get("filename*", None)
+            if upload_name_utf8:
+                if upload_name_utf8.lower().startswith("utf-8''"):
+                    upload_name = upload_name_utf8[7:]
+
+            # If there isn't check for an ascii name.
+            if not upload_name:
+                upload_name_ascii = params.get("filename", None)
+                if upload_name_ascii and is_ascii(upload_name_ascii):
+                    upload_name = upload_name_ascii
+
+            if upload_name:
+                upload_name = urlparse.unquote(upload_name)
                 try:
-                    length, headers = yield self.client.get_file(
-                        server_name, request_path, output_stream=f,
-                        max_size=self.max_upload_size,
-                    )
-                except Exception as e:
-                    logger.warn("Failed to fetch remoted media %r", e)
-                    raise SynapseError(502, "Failed to fetch remoted media")
-
-            media_type = headers["Content-Type"][0]
-            time_now_ms = self.clock.time_msec()
-
-            content_disposition = headers.get("Content-Disposition", None)
-            if content_disposition:
-                _, params = cgi.parse_header(content_disposition[0],)
-                upload_name = None
-
-                # First check if there is a valid UTF-8 filename
-                upload_name_utf8 = params.get("filename*", None)
-                if upload_name_utf8:
-                    if upload_name_utf8.lower().startswith("utf-8''"):
-                        upload_name = upload_name_utf8[7:]
-
-                # If there isn't check for an ascii name.
-                if not upload_name:
-                    upload_name_ascii = params.get("filename", None)
-                    if upload_name_ascii and is_ascii(upload_name_ascii):
-                        upload_name = upload_name_ascii
-
-                if upload_name:
-                    upload_name = urlparse.unquote(upload_name)
-                    try:
-                        upload_name = upload_name.decode("utf-8")
-                    except UnicodeDecodeError:
-                        upload_name = None
-            else:
-                upload_name = None
-
-            logger.info("Stored remote media in file %r", fname)
-
-            yield self.store.store_cached_remote_media(
-                origin=server_name,
-                media_id=media_id,
-                media_type=media_type,
-                time_now_ms=self.clock.time_msec(),
-                upload_name=upload_name,
-                media_length=length,
-                filesystem_id=file_id,
-            )
-        except:
-            os.remove(fname)
-            raise
+                    upload_name = upload_name.decode("utf-8")
+                except UnicodeDecodeError:
+                    upload_name = None
+        else:
+            upload_name = None
+
+        logger.info("Stored remote media in file %r", fname)
+
+        yield self.store.store_cached_remote_media(
+            origin=server_name,
+            media_id=media_id,
+            media_type=media_type,
+            time_now_ms=self.clock.time_msec(),
+            upload_name=upload_name,
+            media_length=length,
+            filesystem_id=file_id,
+        )
 
         media_info = {
             "media_type": media_type,
@@ -215,8 +437,8 @@ class MediaRepository(object):
             "filesystem_id": file_id,
         }
 
-        yield self._generate_remote_thumbnails(
-            server_name, media_id, media_info
+        yield self._generate_thumbnails(
+            server_name, media_id, file_id, media_type,
         )
 
         defer.returnValue(media_info)
@@ -224,9 +446,8 @@ class MediaRepository(object):
     def _get_thumbnail_requirements(self, media_type):
         return self.thumbnail_requirements.get(media_type, ())
 
-    def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+    def _generate_thumbnail(self, thumbnailer, t_width, t_height,
                             t_method, t_type):
-        thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width
         m_height = thumbnailer.height
 
@@ -238,69 +459,126 @@ class MediaRepository(object):
             return
 
         if t_method == "crop":
-            t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+            t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
         elif t_method == "scale":
-            t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+            t_width, t_height = thumbnailer.aspect(t_width, t_height)
+            t_width = min(m_width, t_width)
+            t_height = min(m_height, t_height)
+            t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
         else:
-            t_len = None
+            t_byte_source = None
 
-        return t_len
+        return t_byte_source
 
     @defer.inlineCallbacks
     def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
-                                       t_method, t_type):
-        input_path = self.filepaths.local_media_filepath(media_id)
-
-        t_path = self.filepaths.local_media_thumbnail(
-            media_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
+                                       t_method, t_type, url_cache):
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            None, media_id, url_cache=url_cache,
+        ))
 
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
+        thumbnailer = Thumbnailer(input_path)
+        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
             self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
+            thumbnailer, t_width, t_height, t_method, t_type
+        ))
+
+        if t_byte_source:
+            try:
+                file_info = FileInfo(
+                    server_name=None,
+                    file_id=media_id,
+                    url_cache=url_cache,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
+                )
+            finally:
+                t_byte_source.close()
+
+            logger.info("Stored thumbnail in file %r", output_path)
+
+            t_len = os.path.getsize(output_path)
 
-        if t_len:
             yield self.store.store_local_thumbnail(
                 media_id, t_width, t_height, t_type, t_method, t_len
             )
 
-            defer.returnValue(t_path)
+            defer.returnValue(output_path)
 
     @defer.inlineCallbacks
     def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
                                         t_width, t_height, t_method, t_type):
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-
-        t_path = self.filepaths.remote_media_thumbnail(
-            server_name, file_id, t_width, t_height, t_type, t_method
-        )
-        self._makedirs(t_path)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=False,
+        ))
 
-        t_len = yield preserve_context_over_fn(
-            threads.deferToThread,
+        thumbnailer = Thumbnailer(input_path)
+        t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
             self._generate_thumbnail,
-            input_path, t_path, t_width, t_height, t_method, t_type
-        )
+            thumbnailer, t_width, t_height, t_method, t_type
+        ))
+
+        if t_byte_source:
+            try:
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=media_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                )
+
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
+                )
+            finally:
+                t_byte_source.close()
+
+            logger.info("Stored thumbnail in file %r", output_path)
+
+            t_len = os.path.getsize(output_path)
 
-        if t_len:
             yield self.store.store_remote_media_thumbnail(
                 server_name, media_id, file_id,
                 t_width, t_height, t_type, t_method, t_len
             )
 
-            defer.returnValue(t_path)
+            defer.returnValue(output_path)
 
     @defer.inlineCallbacks
-    def _generate_local_thumbnails(self, media_id, media_info):
-        media_type = media_info["media_type"]
+    def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+                             url_cache=False):
+        """Generate and store thumbnails for an image.
+
+        Args:
+            server_name (str|None): The server name if remote media, else None if local
+            media_id (str): The media ID of the content. (This is the same as
+                the file_id for local content)
+            file_id (str): Local file ID
+            media_type (str): The content type of the file
+            url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+                used exclusively by the url previewer
+
+        Returns:
+            Deferred[dict]: Dict with "width" and "height" keys of original image
+        """
         requirements = self._get_thumbnail_requirements(media_type)
         if not requirements:
             return
 
-        input_path = self.filepaths.local_media_filepath(media_id)
+        input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+            server_name, file_id, url_cache=url_cache,
+        ))
+
         thumbnailer = Thumbnailer(input_path)
         m_width = thumbnailer.width
         m_height = thumbnailer.height
@@ -312,125 +590,68 @@ class MediaRepository(object):
             )
             return
 
-        local_thumbnails = []
-
-        def generate_thumbnails():
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
+        # We deduplicate the thumbnail sizes by ignoring the cropped versions if
+        # they have the same dimensions of a scaled one.
+        thumbnails = {}
+        for r_width, r_height, r_method, r_type in requirements:
+            if r_method == "crop":
+                thumbnails.setdefault((r_width, r_height, r_type), r_method)
+            elif r_method == "scale":
+                t_width, t_height = thumbnailer.aspect(r_width, r_height)
+                t_width = min(m_width, t_width)
+                t_height = min(m_height, t_height)
+                thumbnails[(t_width, t_height, r_type)] = r_method
+
+        # Now we generate the thumbnails for each dimension, store it
+        for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
+            # Generate the thumbnail
+            if t_method == "crop":
+                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+                    thumbnailer.crop,
+                    t_width, t_height, t_type,
                 ))
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.local_media_thumbnail(
-                    media_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                local_thumbnails.append((
-                    media_id, t_width, t_height, t_type, t_method, t_len
+            elif t_method == "scale":
+                t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+                    thumbnailer.scale,
+                    t_width, t_height, t_type,
                 ))
+            else:
+                logger.error("Unrecognized method: %r", t_method)
+                continue
+
+            if not t_byte_source:
+                continue
+
+            try:
+                file_info = FileInfo(
+                    server_name=server_name,
+                    file_id=file_id,
+                    thumbnail=True,
+                    thumbnail_width=t_width,
+                    thumbnail_height=t_height,
+                    thumbnail_method=t_method,
+                    thumbnail_type=t_type,
+                    url_cache=url_cache,
+                )
 
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for l in local_thumbnails:
-            yield self.store.store_local_thumbnail(*l)
-
-        defer.returnValue({
-            "width": m_width,
-            "height": m_height,
-        })
-
-    @defer.inlineCallbacks
-    def _generate_remote_thumbnails(self, server_name, media_id, media_info):
-        media_type = media_info["media_type"]
-        file_id = media_info["filesystem_id"]
-        requirements = self._get_thumbnail_requirements(media_type)
-        if not requirements:
-            return
-
-        remote_thumbnails = []
+                output_path = yield self.media_storage.store_file(
+                    t_byte_source, file_info,
+                )
+            finally:
+                t_byte_source.close()
 
-        input_path = self.filepaths.remote_media_filepath(server_name, file_id)
-        thumbnailer = Thumbnailer(input_path)
-        m_width = thumbnailer.width
-        m_height = thumbnailer.height
+            t_len = os.path.getsize(output_path)
 
-        def generate_thumbnails():
-            if m_width * m_height >= self.max_image_pixels:
-                logger.info(
-                    "Image too large to thumbnail %r x %r > %r",
-                    m_width, m_height, self.max_image_pixels
-                )
-                return
-
-            scales = set()
-            crops = set()
-            for r_width, r_height, r_method, r_type in requirements:
-                if r_method == "scale":
-                    t_width, t_height = thumbnailer.aspect(r_width, r_height)
-                    scales.add((
-                        min(m_width, t_width), min(m_height, t_height), r_type,
-                    ))
-                elif r_method == "crop":
-                    crops.add((r_width, r_height, r_type))
-
-            for t_width, t_height, t_type in scales:
-                t_method = "scale"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
-                )
-                self._makedirs(t_path)
-                t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
+            # Write to database
+            if server_name:
+                yield self.store.store_remote_media_thumbnail(
                     server_name, media_id, file_id,
                     t_width, t_height, t_type, t_method, t_len
-                ])
-
-            for t_width, t_height, t_type in crops:
-                if (t_width, t_height, t_type) in scales:
-                    # If the aspect ratio of the cropped thumbnail matches a purely
-                    # scaled one then there is no point in calculating a separate
-                    # thumbnail.
-                    continue
-                t_method = "crop"
-                t_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, t_width, t_height, t_type, t_method
                 )
-                self._makedirs(t_path)
-                t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
-                remote_thumbnails.append([
-                    server_name, media_id, file_id,
-                    t_width, t_height, t_type, t_method, t_len
-                ])
-
-        yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
-
-        for r in remote_thumbnails:
-            yield self.store.store_remote_media_thumbnail(*r)
+            else:
+                yield self.store.store_local_thumbnail(
+                    media_id, t_width, t_height, t_type, t_method, t_len
+                )
 
         defer.returnValue({
             "width": m_width,
@@ -451,6 +672,8 @@ class MediaRepository(object):
 
             logger.info("Deleting: %r", key)
 
+            # TODO: Should we delete from the backup store
+
             with (yield self.remote_media_linearizer.queue(key)):
                 full_path = self.filepaths.remote_media_filepath(origin, file_id)
                 try:
@@ -525,7 +748,11 @@ class MediaRepositoryResource(Resource):
 
         self.putChild("upload", UploadResource(hs, media_repo))
         self.putChild("download", DownloadResource(hs, media_repo))
-        self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+        self.putChild("thumbnail", ThumbnailResource(
+            hs, media_repo, media_repo.media_storage,
+        ))
         self.putChild("identicon", IdenticonResource())
         if hs.config.url_preview_enabled:
-            self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+            self.putChild("preview_url", PreviewUrlResource(
+                hs, media_repo, media_repo.media_storage,
+            ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..d23fe10b07
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,263 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vecotr Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+import six
+
+from ._base import Responder
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+    """Responsible for storing/fetching files from local sources.
+
+    Args:
+        local_media_directory (str): Base path where we store media on disk
+        filepaths (MediaFilePaths)
+        storage_providers ([StorageProvider]): List of StorageProvider that are
+            used to fetch and store files.
+    """
+
+    def __init__(self, local_media_directory, filepaths, storage_providers):
+        self.local_media_directory = local_media_directory
+        self.filepaths = filepaths
+        self.storage_providers = storage_providers
+
+    @defer.inlineCallbacks
+    def store_file(self, source, file_info):
+        """Write `source` to the on disk media store, and also any other
+        configured storage providers
+
+        Args:
+            source: A file like object that should be written
+            file_info (FileInfo): Info about the file to store
+
+        Returns:
+            Deferred[str]: the file path written to in the primary media store
+        """
+
+        with self.store_into_file(file_info) as (f, fname, finish_cb):
+            # Write to the main repository
+            yield make_deferred_yieldable(threads.deferToThread(
+                _write_file_synchronously, source, f,
+            ))
+            yield finish_cb()
+
+        defer.returnValue(fname)
+
+    @contextlib.contextmanager
+    def store_into_file(self, file_info):
+        """Context manager used to get a file like object to write into, as
+        described by file_info.
+
+        Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+        like object that can be written to, fname is the absolute path of file
+        on disk, and finish_cb is a function that returns a Deferred.
+
+        fname can be used to read the contents from after upload, e.g. to
+        generate thumbnails.
+
+        finish_cb must be called and waited on after the file has been
+        successfully been written to. Should not be called if there was an
+        error.
+
+        Args:
+            file_info (FileInfo): Info about the file to store
+
+        Example:
+
+            with media_storage.store_into_file(info) as (f, fname, finish_cb):
+                # .. write into f ...
+                yield finish_cb()
+        """
+
+        path = self._file_info_to_path(file_info)
+        fname = os.path.join(self.local_media_directory, path)
+
+        dirname = os.path.dirname(fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        finished_called = [False]
+
+        @defer.inlineCallbacks
+        def finish():
+            for provider in self.storage_providers:
+                yield provider.store_file(path, file_info)
+
+            finished_called[0] = True
+
+        try:
+            with open(fname, "wb") as f:
+                yield f, fname, finish
+        except Exception:
+            t, v, tb = sys.exc_info()
+            try:
+                os.remove(fname)
+            except Exception:
+                pass
+            six.reraise(t, v, tb)
+
+        if not finished_called:
+            raise Exception("Finished callback not called")
+
+    @defer.inlineCallbacks
+    def fetch_media(self, file_info):
+        """Attempts to fetch media described by file_info from the local cache
+        and configured storage providers.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[Responder|None]: Returns a Responder if the file was found,
+                otherwise None.
+        """
+
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(FileResponder(open(local_path, "rb")))
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                defer.returnValue(res)
+
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
+    def ensure_media_is_in_local_cache(self, file_info):
+        """Ensures that the given file is in the local cache. Attempts to
+        download it from storage providers if it isn't.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            Deferred[str]: Full path to local file
+        """
+        path = self._file_info_to_path(file_info)
+        local_path = os.path.join(self.local_media_directory, path)
+        if os.path.exists(local_path):
+            defer.returnValue(local_path)
+
+        dirname = os.path.dirname(local_path)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        for provider in self.storage_providers:
+            res = yield provider.fetch(path, file_info)
+            if res:
+                with res:
+                    consumer = BackgroundFileConsumer(open(local_path, "w"))
+                    yield res.write_to_consumer(consumer)
+                    yield consumer.wait()
+                defer.returnValue(local_path)
+
+        raise Exception("file could not be found")
+
+    def _file_info_to_path(self, file_info):
+        """Converts file_info into a relative path.
+
+        The path is suitable for storing files under a directory, e.g. used to
+        store files on local FS under the base media repository directory.
+
+        Args:
+            file_info (FileInfo)
+
+        Returns:
+            str
+        """
+        if file_info.url_cache:
+            if file_info.thumbnail:
+                return self.filepaths.url_cache_thumbnail_rel(
+                    media_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method,
+                )
+            return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+        if file_info.server_name:
+            if file_info.thumbnail:
+                return self.filepaths.remote_media_thumbnail_rel(
+                    server_name=file_info.server_name,
+                    file_id=file_info.file_id,
+                    width=file_info.thumbnail_width,
+                    height=file_info.thumbnail_height,
+                    content_type=file_info.thumbnail_type,
+                    method=file_info.thumbnail_method
+                )
+            return self.filepaths.remote_media_filepath_rel(
+                file_info.server_name, file_info.file_id,
+            )
+
+        if file_info.thumbnail:
+            return self.filepaths.local_media_thumbnail_rel(
+                media_id=file_info.file_id,
+                width=file_info.thumbnail_width,
+                height=file_info.thumbnail_height,
+                content_type=file_info.thumbnail_type,
+                method=file_info.thumbnail_method
+            )
+        return self.filepaths.local_media_filepath_rel(
+            file_info.file_id,
+        )
+
+
+def _write_file_synchronously(source, dest):
+    """Write `source` to the file like `dest` synchronously. Should be called
+    from a thread.
+
+    Args:
+        source: A file like object that's to be written
+        dest: A file like object to be written to
+    """
+    source.seek(0)  # Ensure we read from the start of the file
+    shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+    """Wraps an open file that can be sent to a request.
+
+    Args:
+        open_file (file): A file like object to be streamed ot the client,
+            is closed when finished streaming.
+    """
+    def __init__(self, open_file):
+        self.open_file = open_file
+
+    def write_to_consumer(self, consumer):
+        return make_deferred_yieldable(
+            FileSender().beginFileTransfer(self.open_file, consumer)
+        )
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 99760d622f..9290d7946f 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,39 +12,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import simplejson as json
+import urlparse
 
 from twisted.web.server import NOT_DONE_YET
 from twisted.internet import defer
 from twisted.web.resource import Resource
 
+from ._base import FileInfo
+
 from synapse.api.errors import (
     SynapseError, Codes,
 )
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.stringutils import random_string
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.http.client import SpiderHttpClient
 from synapse.http.server import (
-    request_handler, respond_with_json_bytes
+    request_handler, respond_with_json_bytes,
+    respond_with_json,
 )
 from synapse.util.async import ObservableDeferred
 from synapse.util.stringutils import is_ascii
 
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-
-import logging
 logger = logging.getLogger(__name__)
 
 
 class PreviewUrlResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.auth = hs.get_auth()
@@ -56,19 +64,27 @@ class PreviewUrlResource(Resource):
         self.store = hs.get_datastore()
         self.client = SpiderHttpClient(hs)
         self.media_repo = media_repo
+        self.primary_base_path = media_repo.primary_base_path
+        self.media_storage = media_storage
 
         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
 
-        # simple memory cache mapping urls to OG metadata
-        self.cache = ExpiringCache(
+        # memory cache mapping urls to an ObservableDeferred returning
+        # JSON-encoded OG metadata
+        self._cache = ExpiringCache(
             cache_name="url_previews",
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=60 * 60 * 1000,
         )
-        self.cache.start()
+        self._cache.start()
+
+        self._cleaner_loop = self.clock.looping_call(
+            self._expire_url_cache_data, 10 * 1000
+        )
 
-        self.downloads = {}
+    def render_OPTIONS(self, request):
+        return respond_with_json(request, 200, {}, send_cors=True)
 
     def render_GET(self, request):
         self._async_render_GET(request)
@@ -86,6 +102,7 @@ class PreviewUrlResource(Resource):
         else:
             ts = self.clock.time_msec()
 
+        # XXX: we could move this into _do_preview if we wanted.
         url_tuple = urlparse.urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
@@ -118,53 +135,62 @@ class PreviewUrlResource(Resource):
                     Codes.UNKNOWN
                 )
 
-        # first check the memory cache - good to handle all the clients on this
-        # HS thundering away to preview the same URL at the same time.
-        og = self.cache.get(url)
-        if og:
-            respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
-            return
+        # the in-memory cache:
+        # * ensures that only one request is active at a time
+        # * takes load off the DB for the thundering herds
+        # * also caches any failures (unlike the DB) so we don't keep
+        #    requesting the same endpoint
+
+        observable = self._cache.get(url)
+
+        if not observable:
+            download = run_in_background(
+                self._do_preview,
+                url, requester.user, ts,
+            )
+            observable = ObservableDeferred(
+                download,
+                consumeErrors=True
+            )
+            self._cache[url] = observable
+        else:
+            logger.info("Returning cached response")
 
-        # then check the URL cache in the DB (which will also provide us with
+        og = yield make_deferred_yieldable(observable.observe())
+        respond_with_json_bytes(request, 200, og, send_cors=True)
+
+    @defer.inlineCallbacks
+    def _do_preview(self, url, user, ts):
+        """Check the db, and download the URL and build a preview
+
+        Args:
+            url (str):
+            user (str):
+            ts (int):
+
+        Returns:
+            Deferred[str]: json-encoded og data
+        """
+        # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
         cache_result = yield self.store.get_url_cache(url, ts)
         if (
             cache_result and
-            cache_result["download_ts"] + cache_result["expires"] > ts and
+            cache_result["expires_ts"] > ts and
             cache_result["response_code"] / 100 == 2
         ):
-            respond_with_json_bytes(
-                request, 200, cache_result["og"].encode('utf-8'),
-                send_cors=True
-            )
+            defer.returnValue(cache_result["og"])
             return
 
-        # Ensure only one download for a given URL is active at a time
-        download = self.downloads.get(url)
-        if download is None:
-            download = self._download_url(url, requester.user)
-            download = ObservableDeferred(
-                download,
-                consumeErrors=True
-            )
-            self.downloads[url] = download
-
-            @download.addBoth
-            def callback(media_info):
-                del self.downloads[url]
-                return media_info
-        media_info = yield download.observe()
-
-        # FIXME: we should probably update our cache now anyway, so that
-        # even if the OG calculation raises, we don't keep hammering on the
-        # remote server.  For now, leave it uncached to aid debugging OG
-        # calculation problems
+        media_info = yield self._download_url(url, user)
 
         logger.debug("got media_info of '%s'" % media_info)
 
         if _is_media(media_info['media_type']):
-            dims = yield self.media_repo._generate_local_thumbnails(
-                media_info['filesystem_id'], media_info
+            file_id = media_info['filesystem_id']
+            dims = yield self.media_repo._generate_thumbnails(
+                None, file_id, file_id, media_info["media_type"],
+                url_cache=True,
             )
 
             og = {
@@ -204,13 +230,15 @@ class PreviewUrlResource(Resource):
             # just rely on the caching on the master request to speed things up.
             if 'og:image' in og and og['og:image']:
                 image_info = yield self._download_url(
-                    _rebase_url(og['og:image'], media_info['uri']), requester.user
+                    _rebase_url(og['og:image'], media_info['uri']), user
                 )
 
                 if _is_media(image_info['media_type']):
                     # TODO: make sure we don't choke on white-on-transparent images
-                    dims = yield self.media_repo._generate_local_thumbnails(
-                        image_info['filesystem_id'], image_info
+                    file_id = image_info['filesystem_id']
+                    dims = yield self.media_repo._generate_thumbnails(
+                        None, file_id, file_id, image_info["media_type"],
+                        url_cache=True,
                     )
                     if dims:
                         og["og:image:width"] = dims['width']
@@ -231,21 +259,20 @@ class PreviewUrlResource(Resource):
 
         logger.debug("Calculated OG for %s as %s" % (url, og))
 
-        # store OG in ephemeral in-memory cache
-        self.cache[url] = og
+        jsonog = json.dumps(og)
 
         # store OG in history-aware DB cache
         yield self.store.store_url_cache(
             url,
             media_info["response_code"],
             media_info["etag"],
-            media_info["expires"],
-            json.dumps(og),
+            media_info["expires"] + media_info["created_ts"],
+            jsonog,
             media_info["filesystem_id"],
             media_info["created_ts"],
         )
 
-        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+        defer.returnValue(jsonog)
 
     @defer.inlineCallbacks
     def _download_url(self, url, user):
@@ -253,21 +280,36 @@ class PreviewUrlResource(Resource):
         # we're most likely being explicitly triggered by a human rather than a
         # bot, so are we really a robot?
 
-        # XXX: horrible duplication with base_resource's _download_remote_file()
-        file_id = random_string(24)
+        file_id = datetime.date.today().isoformat() + '_' + random_string(16)
 
-        fname = self.filepaths.local_media_filepath(file_id)
-        self.media_repo._makedirs(fname)
+        file_info = FileInfo(
+            server_name=None,
+            file_id=file_id,
+            url_cache=True,
+        )
 
-        try:
-            with open(fname, "wb") as f:
+        with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+            try:
                 logger.debug("Trying to get url '%s'" % url)
                 length, headers, uri, code = yield self.client.get_file(
                     url, output_stream=f, max_size=self.max_spider_size,
                 )
+            except Exception as e:
                 # FIXME: pass through 404s and other error messages nicely
+                logger.warn("Error downloading %s: %r", url, e)
+                raise SynapseError(
+                    500, "Failed to download content: %s" % (
+                        traceback.format_exception_only(sys.exc_type, e),
+                    ),
+                    Codes.UNKNOWN,
+                )
+            yield finish()
 
-            media_type = headers["Content-Type"][0]
+        try:
+            if "Content-Type" in headers:
+                media_type = headers["Content-Type"][0]
+            else:
+                media_type = "application/octet-stream"
             time_now_ms = self.clock.time_msec()
 
             content_disposition = headers.get("Content-Disposition", None)
@@ -303,14 +345,15 @@ class PreviewUrlResource(Resource):
                 upload_name=download_name,
                 media_length=length,
                 user_id=user,
+                url_cache=url,
             )
 
         except Exception as e:
-            os.remove(fname)
-            raise SynapseError(
-                500, ("Failed to download content: %s" % e),
-                Codes.UNKNOWN
-            )
+            logger.error("Error handling downloaded %s: %r", url, e)
+            # TODO: we really ought to delete the downloaded file in this
+            # case, since we won't have recorded it in the db, and will
+            # therefore not expire it.
+            raise
 
         defer.returnValue({
             "media_type": media_type,
@@ -327,6 +370,95 @@ class PreviewUrlResource(Resource):
             "etag": headers["ETag"][0] if "ETag" in headers else None,
         })
 
+    @defer.inlineCallbacks
+    def _expire_url_cache_data(self):
+        """Clean up expired url cache content, media and thumbnails.
+        """
+        # TODO: Delete from backup media store
+
+        now = self.clock.time_msec()
+
+        logger.info("Running url preview cache expiry")
+
+        if not (yield self.store.has_completed_background_updates()):
+            logger.info("Still running DB updates; skipping expiry")
+            return
+
+        # First we delete expired url cache entries
+        media_ids = yield self.store.get_expired_url_cache(now)
+
+        removed_media = []
+        for media_id in media_ids:
+            fname = self.filepaths.url_cache_filepath(media_id)
+            try:
+                os.remove(fname)
+            except OSError as e:
+                # If the path doesn't exist, meh
+                if e.errno != errno.ENOENT:
+                    logger.warn("Failed to remove media: %r: %s", media_id, e)
+                    continue
+
+            removed_media.append(media_id)
+
+            try:
+                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+                for dir in dirs:
+                    os.rmdir(dir)
+            except Exception:
+                pass
+
+        yield self.store.delete_url_cache(removed_media)
+
+        if removed_media:
+            logger.info("Deleted %d entries from url cache", len(removed_media))
+
+        # Now we delete old images associated with the url cache.
+        # These may be cached for a bit on the client (i.e., they
+        # may have a room open with a preview url thing open).
+        # So we wait a couple of days before deleting, just in case.
+        expire_before = now - 2 * 24 * 60 * 60 * 1000
+        media_ids = yield self.store.get_url_cache_media_before(expire_before)
+
+        removed_media = []
+        for media_id in media_ids:
+            fname = self.filepaths.url_cache_filepath(media_id)
+            try:
+                os.remove(fname)
+            except OSError as e:
+                # If the path doesn't exist, meh
+                if e.errno != errno.ENOENT:
+                    logger.warn("Failed to remove media: %r: %s", media_id, e)
+                    continue
+
+            try:
+                dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+                for dir in dirs:
+                    os.rmdir(dir)
+            except Exception:
+                pass
+
+            thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
+            try:
+                shutil.rmtree(thumbnail_dir)
+            except OSError as e:
+                # If the path doesn't exist, meh
+                if e.errno != errno.ENOENT:
+                    logger.warn("Failed to remove media: %r: %s", media_id, e)
+                    continue
+
+            removed_media.append(media_id)
+
+            try:
+                dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+                for dir in dirs:
+                    os.rmdir(dir)
+            except Exception:
+                pass
+
+        yield self.store.delete_url_cache_media(removed_media)
+
+        logger.info("Deleted %d media from url cache", len(removed_media))
+
 
 def decode_and_calc_og(body, media_uri, request_encoding=None):
     from lxml import etree
@@ -424,7 +556,14 @@ def _calc_og(tree, media_uri):
             from lxml import etree
 
             TAGS_TO_REMOVE = (
-                "header", "nav", "aside", "footer", "script", "style", etree.Comment
+                "header",
+                "nav",
+                "aside",
+                "footer",
+                "script",
+                "noscript",
+                "style",
+                etree.Comment
             )
 
             # Split all the text nodes into paragraphs (by splitting on new
@@ -434,6 +573,8 @@ def _calc_og(tree, media_uri):
                 for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
             )
             og['og:description'] = summarize_paragraphs(text_nodes)
+    else:
+        og['og:description'] = summarize_paragraphs([og['og:description']])
 
     # TODO: delete the url downloads to stop diskfilling,
     # as we only ever cared about its OG
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..0252afd9d3
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.config._base import Config
+from synapse.util.logcontext import run_in_background
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+    """A storage provider is a service that can store uploaded media and
+    retrieve them.
+    """
+    def store_file(self, path, file_info):
+        """Store the file described by file_info. The actual contents can be
+        retrieved by reading the file in file_info.upload_path.
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred
+        """
+        pass
+
+    def fetch(self, path, file_info):
+        """Attempt to fetch the file described by file_info and stream it
+        into writer.
+
+        Args:
+            path (str): Relative path of file in local cache
+            file_info (FileInfo)
+
+        Returns:
+            Deferred(Responder): Returns a Responder if the provider has the file,
+                otherwise returns None.
+        """
+        pass
+
+
+class StorageProviderWrapper(StorageProvider):
+    """Wraps a storage provider and provides various config options
+
+    Args:
+        backend (StorageProvider)
+        store_local (bool): Whether to store new local files or not.
+        store_synchronous (bool): Whether to wait for file to be successfully
+            uploaded, or todo the upload in the backgroud.
+        store_remote (bool): Whether remote media should be uploaded
+    """
+    def __init__(self, backend, store_local, store_synchronous, store_remote):
+        self.backend = backend
+        self.store_local = store_local
+        self.store_synchronous = store_synchronous
+        self.store_remote = store_remote
+
+    def store_file(self, path, file_info):
+        if not file_info.server_name and not self.store_local:
+            return defer.succeed(None)
+
+        if file_info.server_name and not self.store_remote:
+            return defer.succeed(None)
+
+        if self.store_synchronous:
+            return self.backend.store_file(path, file_info)
+        else:
+            # TODO: Handle errors.
+            def store():
+                try:
+                    return self.backend.store_file(path, file_info)
+                except Exception:
+                    logger.exception("Error storing file")
+            run_in_background(store)
+            return defer.succeed(None)
+
+    def fetch(self, path, file_info):
+        return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+    """A storage provider that stores files in a directory on a filesystem.
+
+    Args:
+        hs (HomeServer)
+        config: The config returned by `parse_config`.
+    """
+
+    def __init__(self, hs, config):
+        self.cache_directory = hs.config.media_store_path
+        self.base_directory = config
+
+    def store_file(self, path, file_info):
+        """See StorageProvider.store_file"""
+
+        primary_fname = os.path.join(self.cache_directory, path)
+        backup_fname = os.path.join(self.base_directory, path)
+
+        dirname = os.path.dirname(backup_fname)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        return threads.deferToThread(
+            shutil.copyfile, primary_fname, backup_fname,
+        )
+
+    def fetch(self, path, file_info):
+        """See StorageProvider.fetch"""
+
+        backup_fname = os.path.join(self.base_directory, path)
+        if os.path.isfile(backup_fname):
+            return FileResponder(open(backup_fname, "rb"))
+
+    @staticmethod
+    def parse_config(config):
+        """Called on startup to parse config supplied. This should parse
+        the config and raise if there is a problem.
+
+        The returned value is passed into the constructor.
+
+        In this case we only care about a single param, the directory, so let's
+        just pull that out.
+        """
+        return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d8f54adc99..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
 # limitations under the License.
 
 
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+    parse_media_id, respond_404, respond_with_file, FileInfo,
+    respond_with_responder,
+)
 from twisted.web.resource import Resource
 from synapse.http.servlet import parse_string, parse_integer
 from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
 class ThumbnailResource(Resource):
     isLeaf = True
 
-    def __init__(self, hs, media_repo):
+    def __init__(self, hs, media_repo, media_storage):
         Resource.__init__(self)
 
         self.store = hs.get_datastore()
-        self.filepaths = media_repo.filepaths
         self.media_repo = media_repo
+        self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
         self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
                 yield self._respond_local_thumbnail(
                     request, media_id, width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(None, media_id)
         else:
             if self.dynamic_thumbnails:
                 yield self._select_or_generate_remote_thumbnail(
@@ -75,6 +79,7 @@ class ThumbnailResource(Resource):
                     request, server_name, media_id,
                     width, height, method, m_type
                 )
+            self.media_repo.mark_recently_accessed(server_name, media_id)
 
     @defer.inlineCallbacks
     def _respond_local_thumbnail(self, request, media_id, width, height,
@@ -84,11 +89,10 @@ class ThumbnailResource(Resource):
         if not media_info:
             respond_404(request)
             return
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
+            respond_404(request)
+            return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
 
@@ -96,20 +100,25 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
 
-            file_path = self.filepaths.local_media_thumbnail(
-                media_id, t_width, t_height, t_type, t_method,
+            file_info = FileInfo(
+                server_name=None, file_id=media_id,
+                url_cache=media_info["url_cache"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
             )
-            yield respond_with_file(request, t_type, file_path)
 
+            t_type = file_info.thumbnail_type
+            t_length = thumbnail_info["thumbnail_length"]
+
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, width, height, method, m_type,
-            )
+            logger.info("Couldn't find any generated thumbnails")
+            respond_404(request)
 
     @defer.inlineCallbacks
     def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
@@ -120,11 +129,10 @@ class ThumbnailResource(Resource):
         if not media_info:
             respond_404(request)
             return
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.local_media_filepath(media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        if media_info["quarantined_by"]:
+            logger.info("Media is quarantined")
+            respond_404(request)
+            return
 
         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
         for info in thumbnail_infos:
@@ -134,37 +142,43 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                file_path = self.filepaths.local_media_thumbnail(
-                    media_id, desired_width, desired_height, desired_type, desired_method,
+                file_info = FileInfo(
+                    server_name=None, file_id=media_id,
+                    url_cache=media_info["url_cache"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
                 )
-                yield respond_with_file(request, desired_type, file_path)
-                return
 
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_local_exact_thumbnail(
-            media_id, desired_width, desired_height, desired_method, desired_type
+            media_id, desired_width, desired_height, desired_method, desired_type,
+            url_cache=media_info["url_cache"],
         )
 
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, desired_width, desired_height,
-                desired_method, desired_type,
-            )
+            logger.warn("Failed to generate thumbnail")
+            respond_404(request)
 
     @defer.inlineCallbacks
     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
                                              desired_width, desired_height,
                                              desired_method, desired_type):
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -179,14 +193,24 @@ class ThumbnailResource(Resource):
             t_type = info["thumbnail_type"] == desired_type
 
             if t_w and t_h and t_method and t_type:
-                file_path = self.filepaths.remote_media_thumbnail(
-                    server_name, file_id, desired_width, desired_height,
-                    desired_type, desired_method,
+                file_info = FileInfo(
+                    server_name=server_name, file_id=media_info["filesystem_id"],
+                    thumbnail=True,
+                    thumbnail_width=info["thumbnail_width"],
+                    thumbnail_height=info["thumbnail_height"],
+                    thumbnail_type=info["thumbnail_type"],
+                    thumbnail_method=info["thumbnail_method"],
                 )
-                yield respond_with_file(request, desired_type, file_path)
-                return
 
-        logger.debug("We don't have a local thumbnail of that size. Generating")
+                t_type = file_info.thumbnail_type
+                t_length = info["thumbnail_length"]
+
+                responder = yield self.media_storage.fetch_media(file_info)
+                if responder:
+                    yield respond_with_responder(request, responder, t_type, t_length)
+                    return
+
+        logger.debug("We don't have a thumbnail of that size. Generating")
 
         # Okay, so we generate one.
         file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -197,22 +221,16 @@ class ThumbnailResource(Resource):
         if file_path:
             yield respond_with_file(request, desired_type, file_path)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, desired_width, desired_height,
-                desired_method, desired_type,
-            )
+            logger.warn("Failed to generate thumbnail")
+            respond_404(request)
 
     @defer.inlineCallbacks
     def _respond_remote_thumbnail(self, request, server_name, media_id, width,
                                   height, method, m_type):
         # TODO: Don't download the whole remote file
-        # We should proxy the thumbnail from the remote server instead.
-        media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
-        # if media_info["media_type"] == "image/svg+xml":
-        #     file_path = self.filepaths.remote_media_filepath(server_name, media_id)
-        #     yield respond_with_file(request, media_info["media_type"], file_path)
-        #     return
+        # We should proxy the thumbnail from the remote server instead of
+        # downloading the remote file and generating our own thumbnails.
+        media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
 
         thumbnail_infos = yield self.store.get_remote_media_thumbnails(
             server_name, media_id,
@@ -222,59 +240,23 @@ class ThumbnailResource(Resource):
             thumbnail_info = self._select_thumbnail(
                 width, height, method, m_type, thumbnail_infos
             )
-            t_width = thumbnail_info["thumbnail_width"]
-            t_height = thumbnail_info["thumbnail_height"]
-            t_type = thumbnail_info["thumbnail_type"]
-            t_method = thumbnail_info["thumbnail_method"]
-            file_id = thumbnail_info["filesystem_id"]
+            file_info = FileInfo(
+                server_name=server_name, file_id=media_info["filesystem_id"],
+                thumbnail=True,
+                thumbnail_width=thumbnail_info["thumbnail_width"],
+                thumbnail_height=thumbnail_info["thumbnail_height"],
+                thumbnail_type=thumbnail_info["thumbnail_type"],
+                thumbnail_method=thumbnail_info["thumbnail_method"],
+            )
+
+            t_type = file_info.thumbnail_type
             t_length = thumbnail_info["thumbnail_length"]
 
-            file_path = self.filepaths.remote_media_thumbnail(
-                server_name, file_id, t_width, t_height, t_type, t_method,
-            )
-            yield respond_with_file(request, t_type, file_path, t_length)
+            responder = yield self.media_storage.fetch_media(file_info)
+            yield respond_with_responder(request, responder, t_type, t_length)
         else:
-            yield self._respond_default_thumbnail(
-                request, media_info, width, height, method, m_type,
-            )
-
-    @defer.inlineCallbacks
-    def _respond_default_thumbnail(self, request, media_info, width, height,
-                                   method, m_type):
-        # XXX: how is this meant to work? store.get_default_thumbnails
-        # appears to always return [] so won't this always 404?
-        media_type = media_info["media_type"]
-        top_level_type = media_type.split("/")[0]
-        sub_type = media_type.split("/")[-1].split(";")[0]
-        thumbnail_infos = yield self.store.get_default_thumbnails(
-            top_level_type, sub_type,
-        )
-        if not thumbnail_infos:
-            thumbnail_infos = yield self.store.get_default_thumbnails(
-                top_level_type, "_default",
-            )
-        if not thumbnail_infos:
-            thumbnail_infos = yield self.store.get_default_thumbnails(
-                "_default", "_default",
-            )
-        if not thumbnail_infos:
+            logger.info("Failed to find any generated thumbnails")
             respond_404(request)
-            return
-
-        thumbnail_info = self._select_thumbnail(
-            width, height, "crop", m_type, thumbnail_infos
-        )
-
-        t_width = thumbnail_info["thumbnail_width"]
-        t_height = thumbnail_info["thumbnail_height"]
-        t_type = thumbnail_info["thumbnail_type"]
-        t_method = thumbnail_info["thumbnail_method"]
-        t_length = thumbnail_info["thumbnail_length"]
-
-        file_path = self.filepaths.default_thumbnail(
-            top_level_type, sub_type, t_width, t_height, t_type, t_method,
-        )
-        yield respond_with_file(request, t_type, file_path, t_length)
 
     def _select_thumbnail(self, desired_width, desired_height, desired_method,
                           desired_type, thumbnail_infos):
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3868d4f65f..e1ee535b9a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -50,12 +50,16 @@ class Thumbnailer(object):
         else:
             return ((max_height * self.width) // self.height, max_height)
 
-    def scale(self, output_path, width, height, output_type):
-        """Rescales the image to the given dimensions"""
+    def scale(self, width, height, output_type):
+        """Rescales the image to the given dimensions.
+
+        Returns:
+            BytesIO: the bytes of the encoded image ready to be written to disk
+        """
         scaled = self.image.resize((width, height), Image.ANTIALIAS)
-        return self.save_image(scaled, output_type, output_path)
+        return self._encode_image(scaled, output_type)
 
-    def crop(self, output_path, width, height, output_type):
+    def crop(self, width, height, output_type):
         """Rescales and crops the image to the given dimensions preserving
         aspect::
             (w_in / h_in) = (w_scaled / h_scaled)
@@ -65,6 +69,9 @@ class Thumbnailer(object):
         Args:
             max_width: The largest possible width.
             max_height: The larget possible height.
+
+        Returns:
+            BytesIO: the bytes of the encoded image ready to be written to disk
         """
         if width * self.height > height * self.width:
             scaled_height = (width * self.height) // self.width
@@ -82,13 +89,9 @@ class Thumbnailer(object):
             crop_left = (scaled_width - width) // 2
             crop_right = width + crop_left
             cropped = scaled_image.crop((crop_left, 0, crop_right, height))
-        return self.save_image(cropped, output_type, output_path)
+        return self._encode_image(cropped, output_type)
 
-    def save_image(self, output_image, output_type, output_path):
+    def _encode_image(self, output_image, output_type):
         output_bytes_io = BytesIO()
         output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
-        output_bytes = output_bytes_io.getvalue()
-        with open(output_path, "wb") as output_file:
-            output_file.write(output_bytes)
-        logger.info("Stored thumbnail in file %r", output_path)
-        return len(output_bytes)
+        return output_bytes_io
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 4ab33f73bf..a31e75cb46 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -81,19 +81,19 @@ class UploadResource(Resource):
         headers = request.requestHeaders
 
         if headers.hasHeader("Content-Type"):
-            media_type = headers.getRawHeaders("Content-Type")[0]
+            media_type = headers.getRawHeaders(b"Content-Type")[0]
         else:
             raise SynapseError(
                 msg="Upload request missing 'Content-Type'",
                 code=400,
             )
 
-        # if headers.hasHeader("Content-Disposition"):
-        #     disposition = headers.getRawHeaders("Content-Disposition")[0]
+        # if headers.hasHeader(b"Content-Disposition"):
+        #     disposition = headers.getRawHeaders(b"Content-Disposition")[0]
         # TODO(markjh): parse content-dispostion
 
         content_uri = yield self.media_repo.create_content(
-            media_type, upload_name, request.content.read(),
+            media_type, upload_name, request.content,
             content_length, requester.user
         )
 
diff --git a/synapse/server.py b/synapse/server.py
index c577032041..ebdea6b0c4 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,29 +31,47 @@ from synapse.appservice.api import ApplicationServiceApi
 from synapse.appservice.scheduler import ApplicationServiceScheduler
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
-from synapse.federation import initialize_http_replication
+from synapse.events.spamcheck import SpamChecker
+from synapse.federation.federation_client import FederationClient
+from synapse.federation.federation_server import FederationServer
 from synapse.federation.send_queue import FederationRemoteSendQueue
+from synapse.federation.federation_server import FederationHandlerRegistry
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.federation.transaction_queue import TransactionQueue
 from synapse.handlers import Handlers
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
+from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.devicemessage import DeviceMessageHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
 from synapse.handlers.room_list import RoomListHandler
+from synapse.handlers.room_member import RoomMemberMasterHandler
+from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
+from synapse.handlers.set_password import SetPasswordHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
 from synapse.handlers.events import EventHandler, EventStreamHandler
 from synapse.handlers.initial_sync import InitialSyncHandler
 from synapse.handlers.receipts import ReceiptsHandler
+from synapse.handlers.read_marker import ReadMarkerHandler
+from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
+from synapse.groups.groups_server import GroupsServerHandler
+from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
 from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
+from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
-from synapse.rest.media.v1.media_repository import MediaRepository
-from synapse.state import StateHandler
+from synapse.rest.media.v1.media_repository import (
+    MediaRepository,
+    MediaRepositoryResource,
+)
+from synapse.state import StateHandler, StateResolutionHandler
 from synapse.storage import DataStore
 from synapse.streams.events import EventSources
 from synapse.util import Clock
@@ -82,18 +100,14 @@ class HomeServer(object):
     """
 
     DEPENDENCIES = [
-        'config',
-        'clock',
         'http_client',
         'db_pool',
-        'persistence_service',
-        'replication_layer',
-        'datastore',
+        'federation_client',
+        'federation_server',
         'handlers',
-        'v1auth',
         'auth',
-        'rest_servlet_factory',
         'state_handler',
+        'state_resolution_handler',
         'presence_handler',
         'sync_handler',
         'typing_handler',
@@ -108,19 +122,12 @@ class HomeServer(object):
         'application_service_scheduler',
         'application_service_handler',
         'device_message_handler',
+        'profile_handler',
+        'event_creation_handler',
+        'deactivate_account_handler',
+        'set_password_handler',
         'notifier',
-        'distributor',
-        'client_resource',
-        'resource_for_federation',
-        'resource_for_static_content',
-        'resource_for_web_client',
-        'resource_for_content_repo',
-        'resource_for_server_key',
-        'resource_for_server_key_v2',
-        'resource_for_media_repository',
-        'resource_for_metrics',
         'event_sources',
-        'ratelimiter',
         'keyring',
         'pusherpool',
         'event_builder_factory',
@@ -128,10 +135,22 @@ class HomeServer(object):
         'http_client_context_factory',
         'simple_http_client',
         'media_repository',
+        'media_repository_resource',
         'federation_transport_client',
         'federation_sender',
         'receipts_handler',
         'macaroon_generator',
+        'tcp_replication',
+        'read_marker_handler',
+        'action_generator',
+        'user_directory_handler',
+        'groups_local_handler',
+        'groups_server_handler',
+        'groups_attestation_signing',
+        'groups_attestation_renewer',
+        'spam_checker',
+        'room_member_handler',
+        'federation_registry',
     ]
 
     def __init__(self, hostname, **kwargs):
@@ -165,8 +184,26 @@ class HomeServer(object):
     def is_mine_id(self, string):
         return string.split(":", 1)[1] == self.hostname
 
-    def build_replication_layer(self):
-        return initialize_http_replication(self)
+    def get_clock(self):
+        return self.clock
+
+    def get_datastore(self):
+        return self.datastore
+
+    def get_config(self):
+        return self.config
+
+    def get_distributor(self):
+        return self.distributor
+
+    def get_ratelimiter(self):
+        return self.ratelimiter
+
+    def build_federation_client(self):
+        return FederationClient(self)
+
+    def build_federation_server(self):
+        return FederationServer(self)
 
     def build_handlers(self):
         return Handlers(self)
@@ -187,18 +224,12 @@ class HomeServer(object):
     def build_simple_http_client(self):
         return SimpleHttpClient(self)
 
-    def build_v1auth(self):
-        orf = Auth(self)
-        # Matrix spec makes no reference to what HTTP status code is returned,
-        # but the V1 API uses 403 where it means 401, and the webclient
-        # relies on this behaviour, so V1 gets its own copy of the auth
-        # with backwards compat behaviour.
-        orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
-        return orf
-
     def build_state_handler(self):
         return StateHandler(self)
 
+    def build_state_resolution_handler(self):
+        return StateResolutionHandler(self)
+
     def build_presence_handler(self):
         return PresenceHandler(self)
 
@@ -244,6 +275,18 @@ class HomeServer(object):
     def build_initial_sync_handler(self):
         return InitialSyncHandler(self)
 
+    def build_profile_handler(self):
+        return ProfileHandler(self)
+
+    def build_event_creation_handler(self):
+        return EventCreationHandler(self)
+
+    def build_deactivate_account_handler(self):
+        return DeactivateAccountHandler(self)
+
+    def build_set_password_handler(self):
+        return SetPasswordHandler(self)
+
     def build_event_sources(self):
         return EventSources(self)
 
@@ -273,6 +316,28 @@ class HomeServer(object):
             **self.db_config.get("args", {})
         )
 
+    def get_db_conn(self, run_new_connection=True):
+        """Makes a new connection to the database, skipping the db pool
+
+        Returns:
+            Connection: a connection object implementing the PEP-249 spec
+        """
+        # Any param beginning with cp_ is a parameter for adbapi, and should
+        # not be passed to the database engine.
+        db_params = {
+            k: v for k, v in self.db_config.get("args", {}).items()
+            if not k.startswith("cp_")
+        }
+        db_conn = self.database_engine.module.connect(**db_params)
+        if run_new_connection:
+            self.database_engine.on_new_connection(db_conn)
+        return db_conn
+
+    def build_media_repository_resource(self):
+        # build the media repo resource. This indirects through the HomeServer
+        # to ensure that we only have a single instance of
+        return MediaRepositoryResource(self)
+
     def build_media_repository(self):
         return MediaRepository(self)
 
@@ -290,6 +355,41 @@ class HomeServer(object):
     def build_receipts_handler(self):
         return ReceiptsHandler(self)
 
+    def build_read_marker_handler(self):
+        return ReadMarkerHandler(self)
+
+    def build_tcp_replication(self):
+        raise NotImplementedError()
+
+    def build_action_generator(self):
+        return ActionGenerator(self)
+
+    def build_user_directory_handler(self):
+        return UserDirectoryHandler(self)
+
+    def build_groups_local_handler(self):
+        return GroupsLocalHandler(self)
+
+    def build_groups_server_handler(self):
+        return GroupsServerHandler(self)
+
+    def build_groups_attestation_signing(self):
+        return GroupAttestationSigning(self)
+
+    def build_groups_attestation_renewer(self):
+        return GroupAttestionRenewer(self)
+
+    def build_spam_checker(self):
+        return SpamChecker(self)
+
+    def build_room_member_handler(self):
+        if self.config.worker_app:
+            return RoomMemberWorkerHandler(self)
+        return RoomMemberMasterHandler(self)
+
+    def build_federation_registry(self):
+        return FederationHandlerRegistry()
+
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9570df5537..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,10 +1,16 @@
 import synapse.api.auth
+import synapse.federation.transaction_queue
+import synapse.federation.transport.client
 import synapse.handlers
 import synapse.handlers.auth
+import synapse.handlers.deactivate_account
 import synapse.handlers.device
 import synapse.handlers.e2e_keys
-import synapse.storage
+import synapse.handlers.set_password
+import synapse.rest.media.v1.media_repository
 import synapse.state
+import synapse.storage
+
 
 class HomeServer(object):
     def get_auth(self) -> synapse.api.auth.Auth:
@@ -27,3 +33,24 @@ class HomeServer(object):
 
     def get_state_handler(self) -> synapse.state.StateHandler:
         pass
+
+    def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+        pass
+
+    def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+        pass
+
+    def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+        pass
+
+    def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
+        pass
+
+    def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+        pass
+
+    def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+        pass
+
+    def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+        pass
diff --git a/synapse/state.py b/synapse/state.py
index 383d32b163..26093c8434 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -24,13 +24,13 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.events.snapshot import EventContext
 from synapse.util.async import Linearizer
+from synapse.util.caches import CACHE_SIZE_FACTOR
 
 from collections import namedtuple
 from frozendict import frozendict
 
 import logging
 import hashlib
-import os
 
 logger = logging.getLogger(__name__)
 
@@ -38,9 +38,6 @@ logger = logging.getLogger(__name__)
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
 EVICTION_TIMEOUT_SECONDS = 60 * 60
 
@@ -61,7 +58,11 @@ class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
     def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+        # dict[(str, str), str] map  from (type, state_key) to event_id
         self.state = frozendict(state)
+
+        # the ID of a state group if one and only one is involved.
+        # otherwise, None otherwise?
         self.state_group = state_group
 
         self.prev_group = prev_group
@@ -84,31 +85,19 @@ class _StateCacheEntry(object):
 
 
 class StateHandler(object):
-    """ Responsible for doing state conflict resolution.
+    """Fetches bits of state from the stores, and does state resolution
+    where necessary
     """
 
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.hs = hs
-
-        # dict of set of event_ids -> _StateCacheEntry.
-        self._state_cache = None
-        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+        self._state_resolution_handler = hs.get_state_resolution_handler()
 
     def start_caching(self):
-        logger.debug("start_caching")
-
-        self._state_cache = ExpiringCache(
-            cache_name="state_cache",
-            clock=self.clock,
-            max_len=SIZE_OF_CACHE,
-            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
-            iterable=True,
-            reset_expiry_on_get=True,
-        )
-
-        self._state_cache.start()
+        # TODO: remove this shim
+        self._state_resolution_handler.start_caching()
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key="",
@@ -130,7 +119,7 @@ class StateHandler(object):
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         if event_type:
@@ -143,25 +132,33 @@ class StateHandler(object):
 
         state_map = yield self.store.get_events(state.values(), get_prev_content=False)
         state = {
-            key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+            key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
         }
 
         defer.returnValue(state)
 
     @defer.inlineCallbacks
-    def get_current_state_ids(self, room_id, event_type=None, state_key="",
-                              latest_event_ids=None):
+    def get_current_state_ids(self, room_id, latest_event_ids=None):
+        """Get the current state, or the state at a set of events, for a room
+
+        Args:
+            room_id (str):
+
+            latest_event_ids (iterable[str]|None): if given, the forward
+                extremities to resolve. If None, we look them up from the
+                database (via a cache)
+
+        Returns:
+            Deferred[dict[(str, str), str)]]: the state dict, mapping from
+                (event_type, state_key) -> event_id
+        """
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
-        if event_type:
-            defer.returnValue(state.get((event_type, state_key)))
-            return
-
         defer.returnValue(state)
 
     @defer.inlineCallbacks
@@ -169,54 +166,71 @@ class StateHandler(object):
         if not latest_event_ids:
             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
         logger.debug("calling resolve_state_groups from get_current_user_in_room")
-        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
-        joined_users = yield self.store.get_joined_users_from_state(
-            room_id, entry.state_id, entry.state
-        )
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
         defer.returnValue(joined_users)
 
     @defer.inlineCallbacks
+    def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
+        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+        defer.returnValue(joined_hosts)
+
+    @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
-        """ Fills out the context with the `current state` of the graph. The
-        `current state` here is defined to be the state of the event graph
-        just before the event - i.e. it never includes `event`
+        """Build an EventContext structure for the event.
 
-        If `event` has `auth_events` then this will also fill out the
-        `auth_events` field on `context` from the `current_state`.
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
 
         Args:
-            event (EventBase)
+            event (synapse.events.EventBase):
+            old_state (dict|None): The state at the event if it can't be
+                calculated from existing events. This is normally only specified
+                when receiving an event from federation where we don't have the
+                prev events for, e.g. when backfilling.
         Returns:
-            an EventContext
+            synapse.events.snapshot.EventContext:
         """
-        context = EventContext()
 
         if event.internal_metadata.is_outlier():
             # If this is an outlier, then we know it shouldn't have any current
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
+            context = EventContext()
             if old_state:
                 context.prev_state_ids = {
                     (s.type, s.state_key): s.event_id for s in old_state
                 }
                 if event.is_state():
-                    context.current_state_events = dict(context.prev_state_ids)
+                    context.current_state_ids = dict(context.prev_state_ids)
                     key = (event.type, event.state_key)
-                    context.current_state_events[key] = event.event_id
+                    context.current_state_ids[key] = event.event_id
                 else:
-                    context.current_state_events = context.prev_state_ids
+                    context.current_state_ids = context.prev_state_ids
             else:
                 context.current_state_ids = {}
                 context.prev_state_ids = {}
             context.prev_state_events = []
-            context.state_group = self.store.get_next_state_group()
+
+            # We don't store state for outliers, so we don't generate a state
+            # froup for it.
+            context.state_group = None
+
             defer.returnValue(context)
 
         if old_state:
+            # We already have the state, so we don't need to calculate it.
+            # Let's just correctly fill out the context and create a
+            # new state group for it.
+
+            context = EventContext()
             context.prev_state_ids = {
                 (s.type, s.state_key): s.event_id for s in old_state
             }
-            context.state_group = self.store.get_next_state_group()
 
             if event.is_state():
                 key = (event.type, event.state_key)
@@ -229,26 +243,29 @@ class StateHandler(object):
             else:
                 context.current_state_ids = context.prev_state_ids
 
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=None,
+                delta_ids=None,
+                current_state_ids=context.current_state_ids,
+            )
+
             context.prev_state_events = []
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
-        if event.is_state():
-            entry = yield self.resolve_state_groups(
-                event.room_id, [e for e, _ in event.prev_events],
-                event_type=event.type,
-                state_key=event.state_key,
-            )
-        else:
-            entry = yield self.resolve_state_groups(
-                event.room_id, [e for e, _ in event.prev_events],
-            )
+        entry = yield self.resolve_state_groups_for_events(
+            event.room_id, [e for e, _ in event.prev_events],
+        )
 
         curr_state = entry.state
 
+        context = EventContext()
         context.prev_state_ids = curr_state
         if event.is_state():
-            context.state_group = self.store.get_next_state_group()
+            # If this is a state event then we need to create a new state
+            # group for the state after this event.
 
             key = (event.type, event.state_key)
             if key in context.prev_state_ids:
@@ -258,58 +275,176 @@ class StateHandler(object):
             context.current_state_ids = dict(context.prev_state_ids)
             context.current_state_ids[key] = event.event_id
 
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
-            if context.delta_ids is not None:
-                context.delta_ids = dict(context.delta_ids)
+            if entry.state_group:
+                # If the state at the event has a state group assigned then
+                # we can use that as the prev group
+                context.prev_group = entry.state_group
+                context.delta_ids = {
+                    key: event.event_id
+                }
+            elif entry.prev_group:
+                # If the state at the event only has a prev group, then we can
+                # use that as a prev group too.
+                context.prev_group = entry.prev_group
+                context.delta_ids = dict(entry.delta_ids)
                 context.delta_ids[key] = event.event_id
+
+            context.state_group = yield self.store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=context.prev_group,
+                delta_ids=context.delta_ids,
+                current_state_ids=context.current_state_ids,
+            )
         else:
+            context.current_state_ids = context.prev_state_ids
+            context.prev_group = entry.prev_group
+            context.delta_ids = entry.delta_ids
+
             if entry.state_group is None:
-                entry.state_group = self.store.get_next_state_group()
+                entry.state_group = yield self.store.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=entry.prev_group,
+                    delta_ids=entry.delta_ids,
+                    current_state_ids=context.current_state_ids,
+                )
                 entry.state_id = entry.state_group
 
             context.state_group = entry.state_group
-            context.current_state_ids = context.prev_state_ids
-            context.prev_group = entry.prev_group
-            context.delta_ids = entry.delta_ids
 
         context.prev_state_events = []
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    @log_function
-    def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
+    def resolve_state_groups_for_events(self, room_id, event_ids):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
+        Args:
+            room_id (str):
+            event_ids (list[str]):
+
         Returns:
-            a Deferred tuple of (`state_group`, `state`, `prev_state`).
-            `state_group` is the name of a state group if one and only one is
-            involved. `state` is a map from (type, state_key) to event, and
-            `prev_state` is a list of event ids.
+            Deferred[_StateCacheEntry]: resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
+        # map from state group id to the state in that state group (where
+        # 'state' is a map from state key to event id)
+        # dict[int, dict[(str, str), str]]
         state_groups_ids = yield self.store.get_state_groups_ids(
             room_id, event_ids
         )
 
-        logger.debug(
-            "resolve_state_groups state_groups %s",
-            state_groups_ids.keys()
-        )
-
-        group_names = frozenset(state_groups_ids.keys())
-        if len(group_names) == 1:
+        if len(state_groups_ids) == 1:
             name, state_list = state_groups_ids.items().pop()
 
+            prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+
             defer.returnValue(_StateCacheEntry(
                 state=state_list,
                 state_group=name,
-                prev_group=name,
-                delta_ids={},
+                prev_group=prev_group,
+                delta_ids=delta_ids,
             ))
 
+        result = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups_ids, None, self._state_map_factory,
+        )
+        defer.returnValue(result)
+
+    def _state_map_factory(self, ev_ids):
+        return self.store.get_events(
+            ev_ids, get_prev_content=False, check_redacted=False,
+        )
+
+    def resolve_events(self, state_sets, event):
+        logger.info(
+            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+        )
+        state_set_ids = [{
+            (ev.type, ev.state_key): ev.event_id
+            for ev in st
+        } for st in state_sets]
+
+        state_map = {
+            ev.event_id: ev
+            for st in state_sets
+            for ev in st
+        }
+
+        with Measure(self.clock, "state._resolve_events"):
+            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+        new_state = {
+            key: state_map[ev_id] for key, ev_id in new_state.iteritems()
+        }
+
+        return new_state
+
+
+class StateResolutionHandler(object):
+    """Responsible for doing state conflict resolution.
+
+    Note that the storage layer depends on this handler, so all functions must
+    be storage-independent.
+    """
+    def __init__(self, hs):
+        self.clock = hs.get_clock()
+
+        # dict of set of event_ids -> _StateCacheEntry.
+        self._state_cache = None
+        self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+    def start_caching(self):
+        logger.debug("start_caching")
+
+        self._state_cache = ExpiringCache(
+            cache_name="state_cache",
+            clock=self.clock,
+            max_len=SIZE_OF_CACHE,
+            expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+            iterable=True,
+            reset_expiry_on_get=True,
+        )
+
+        self._state_cache.start()
+
+    @defer.inlineCallbacks
+    @log_function
+    def resolve_state_groups(
+        self, room_id, state_groups_ids, event_map, state_map_factory,
+    ):
+        """Resolves conflicts between a set of state groups
+
+        Always generates a new state group (unless we hit the cache), so should
+        not be called for a single state group
+
+        Args:
+            room_id (str): room we are resolving for (used for logging)
+            state_groups_ids (dict[int, dict[(str, str), str]]):
+                 map from state group id to the state in that state group
+                (where 'state' is a map from state key to event id)
+
+            event_map(dict[str,FrozenEvent]|None):
+                a dict from event_id to event, for any events that we happen to
+                have in flight (eg, those currently being persisted). This will be
+                used as a starting point fof finding the state we need; any missing
+                events will be requested via state_map_factory.
+
+                If None, all events will be fetched via state_map_factory.
+
+        Returns:
+            Deferred[_StateCacheEntry]: resolved state
+        """
+        logger.debug(
+            "resolve_state_groups state_groups %s",
+            state_groups_ids.keys()
+        )
+
+        group_names = frozenset(state_groups_ids.keys())
+
         with (yield self.resolve_linearizer.queue(group_names)):
             if self._state_cache is not None:
                 cache = self._state_cache.get(group_names, None)
@@ -320,56 +455,62 @@ class StateHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
+            # build a map from state key to the event_ids which set that state.
+            # dict[(str, str), set[str])
             state = {}
-            for st in state_groups_ids.values():
-                for key, e_id in st.items():
+            for st in state_groups_ids.itervalues():
+                for key, e_id in st.iteritems():
                     state.setdefault(key, set()).add(e_id)
 
+            # build a map from state key to the event_ids which set that state,
+            # including only those where there are state keys in conflict.
             conflicted_state = {
                 k: list(v)
-                for k, v in state.items()
+                for k, v in state.iteritems()
                 if len(v) > 1
             }
 
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events(
+                    new_state = yield resolve_events_with_factory(
                         state_groups_ids.values(),
-                        state_map_factory=lambda ev_ids: self.store.get_events(
-                            ev_ids, get_prev_content=False, check_redacted=False,
-                        ),
+                        event_map=event_map,
+                        state_map_factory=state_map_factory,
                     )
             else:
                 new_state = {
-                    key: e_ids.pop() for key, e_ids in state.items()
+                    key: e_ids.pop() for key, e_ids in state.iteritems()
                 }
 
-            state_group = None
-            new_state_event_ids = frozenset(new_state.values())
-            for sg, events in state_groups_ids.items():
-                if new_state_event_ids == frozenset(e_id for e_id in events):
-                    state_group = sg
-                    break
-            if state_group is None:
-                # Worker instances don't have access to this method, but we want
-                # to set the state_group on the main instance to increase cache
-                # hits.
-                if hasattr(self.store, "get_next_state_group"):
-                    state_group = self.store.get_next_state_group()
-
-            prev_group = None
-            delta_ids = None
-            for old_group, old_ids in state_groups_ids.items():
-                if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
-                    n_delta_ids = {
-                        k: v
-                        for k, v in new_state.items()
-                        if old_ids.get(k) != v
-                    }
-                    if not delta_ids or len(n_delta_ids) < len(delta_ids):
-                        prev_group = old_group
-                        delta_ids = n_delta_ids
+            with Measure(self.clock, "state.create_group_ids"):
+                # if the new state matches any of the input state groups, we can
+                # use that state group again. Otherwise we will generate a state_id
+                # which will be used as a cache key for future resolutions, but
+                # not get persisted.
+                state_group = None
+                new_state_event_ids = frozenset(new_state.itervalues())
+                for sg, events in state_groups_ids.iteritems():
+                    if new_state_event_ids == frozenset(e_id for e_id in events):
+                        state_group = sg
+                        break
+
+                # TODO: We want to create a state group for this set of events, to
+                # increase cache hits, but we need to make sure that it doesn't
+                # end up as a prev_group without being added to the database
+
+                prev_group = None
+                delta_ids = None
+                for old_group, old_ids in state_groups_ids.iteritems():
+                    if not set(new_state) - set(old_ids):
+                        n_delta_ids = {
+                            k: v
+                            for k, v in new_state.iteritems()
+                            if old_ids.get(k) != v
+                        }
+                        if not delta_ids or len(n_delta_ids) < len(delta_ids):
+                            prev_group = old_group
+                            delta_ids = n_delta_ids
 
             cache = _StateCacheEntry(
                 state=new_state,
@@ -383,30 +524,6 @@ class StateHandler(object):
 
             defer.returnValue(cache)
 
-    def resolve_events(self, state_sets, event):
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [{
-            (ev.type, ev.state_key): ev.event_id
-            for ev in st
-        } for st in state_sets]
-
-        state_map = {
-            ev.event_id: ev
-            for st in state_sets
-            for ev in st
-        }
-
-        with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events(state_set_ids, state_map)
-
-        new_state = {
-            key: state_map[ev_id] for key, ev_id in new_state.items()
-        }
-
-        return new_state
-
 
 def _ordered_events(events):
     def key_func(e):
@@ -415,19 +532,17 @@ def _ordered_events(events):
     return sorted(events, key=key_func)
 
 
-def resolve_events(state_sets, state_map_factory):
+def resolve_events_with_state_map(state_sets, state_map):
     """
     Args:
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
-        state_map_factory(dict|callable): If callable, then will be called
-            with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event. Otherwise, should be
-            a dict from event_id to event of all events in state_sets.
+        state_map(dict): a dict from event_id to event, for all events in
+            state_sets.
 
     Returns
-        dict[(str, str), synapse.events.FrozenEvent] is a map from
-        (type, state_key) to event.
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -436,13 +551,6 @@ def resolve_events(state_sets, state_map_factory):
         state_sets,
     )
 
-    if callable(state_map_factory):
-        return _resolve_with_state_fac(
-            unconflicted_state, conflicted_state, state_map_factory
-        )
-
-    state_map = state_map_factory
-
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -456,6 +564,21 @@ def _seperate(state_sets):
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
+
+    Args:
+        state_sets(list[dict[(str, str), str]]):
+            List of dicts of (type, state_key) -> event_id, which are the
+            different state groups to resolve.
+
+    Returns:
+        (dict[(str, str), str], dict[(str, str), set[str]]):
+            A tuple of (unconflicted_state, conflicted_state), where:
+
+            unconflicted_state is a dict mapping (type, state_key)->event_id
+            for unconflicted state keys.
+
+            conflicted_state is a dict mapping (type, state_key) to a set of
+            event ids for conflicted state keys.
     """
     unconflicted_state = dict(state_sets[0])
     conflicted_state = {}
@@ -486,24 +609,63 @@ def _seperate(state_sets):
 
 
 @defer.inlineCallbacks
-def _resolve_with_state_fac(unconflicted_state, conflicted_state,
-                            state_map_factory):
+def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+    """
+    Args:
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+
+        event_map(dict[str,FrozenEvent]|None):
+            a dict from event_id to event, for any events that we happen to
+            have in flight (eg, those currently being persisted). This will be
+            used as a starting point fof finding the state we need; any missing
+            events will be requested via state_map_factory.
+
+            If None, all events will be fetched via state_map_factory.
+
+        state_map_factory(func): will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event.
+
+    Returns
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
+    """
+    if len(state_sets) == 1:
+        defer.returnValue(state_sets[0])
+
+    unconflicted_state, conflicted_state = _seperate(
+        state_sets,
+    )
+
     needed_events = set(
         event_id
         for event_ids in conflicted_state.itervalues()
         for event_id in event_ids
     )
+    if event_map is not None:
+        needed_events -= set(event_map.iterkeys())
 
     logger.info("Asking for %d conflicted events", len(needed_events))
 
+    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+    # the state events which are in conflict (and those in event_map)
     state_map = yield state_map_factory(needed_events)
+    if event_map is not None:
+        state_map.update(event_map)
 
+    # get the ids of the auth events which allow us to authenticate the
+    # conflicted state, picking only from the unconflicting state.
+    #
+    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
 
     new_needed_events = set(auth_events.itervalues())
     new_needed_events -= needed_events
+    if event_map is not None:
+        new_needed_events -= set(event_map.iterkeys())
 
     logger.info("Asking for %d auth events", len(new_needed_events))
 
@@ -541,7 +703,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
 
     auth_events = {
         key: state_map[ev_id]
-        for key, ev_id in auth_event_ids.items()
+        for key, ev_id in auth_event_ids.iteritems()
         if ev_id in state_map
     }
 
@@ -549,7 +711,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
         resolved_state = _resolve_state_events(
             conflicted_state, auth_events
         )
-    except:
+    except Exception:
         logger.exception("Failed to resolve state")
         raise
 
@@ -579,7 +741,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in conflicted_state.iteritems():
         if key[0] == EventTypes.JoinRules:
             logger.debug("Resolving conflicted join rules %r", events)
             resolved_state[key] = _resolve_auth_events(
@@ -589,7 +751,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in conflicted_state.iteritems():
         if key[0] == EventTypes.Member:
             logger.debug("Resolving conflicted member lists %r", events)
             resolved_state[key] = _resolve_auth_events(
@@ -599,7 +761,7 @@ def _resolve_state_events(conflicted_state, auth_events):
 
     auth_events.update(resolved_state)
 
-    for key, events in conflicted_state.items():
+    for key, events in conflicted_state.iteritems():
         if key not in resolved_state:
             logger.debug("Resolving conflicted state %r:%r", key, events)
             resolved_state[key] = _resolve_normal_events(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d604e7668f..8cdfd50f90 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 from synapse.storage.devices import DeviceStore
 from .appservice import (
     ApplicationServiceStore, ApplicationServiceTransactionStore
 )
-from ._base import LoggingTransaction
 from .directory import DirectoryStore
 from .events import EventsStore
 from .presence import PresenceStore, UserPresenceState
@@ -37,7 +35,7 @@ from .media_repository import MediaRepositoryStore
 from .rejections import RejectionsStore
 from .event_push_actions import EventPushActionsStore
 from .deviceinbox import DeviceInboxStore
-
+from .group_server import GroupServerStore
 from .state import StateStore
 from .signatures import SignatureStore
 from .filtering import FilteringStore
@@ -49,6 +47,7 @@ from .tags import TagsStore
 from .account_data import AccountDataStore
 from .openid import OpenIdStore
 from .client_ips import ClientIpStore
+from .user_directory import UserDirectoryStore
 
 from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
 from .engines import PostgresEngine
@@ -86,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 ClientIpStore,
                 DeviceStore,
                 DeviceInboxStore,
+                UserDirectoryStore,
+                GroupServerStore,
                 ):
 
     def __init__(self, db_conn, hs):
@@ -101,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "events", "stream_ordering", step=-1,
             extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
         )
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn, "account_data_max_stream_id", "stream_id"
-        )
         self._presence_id_gen = StreamIdGenerator(
             db_conn, "presence_stream", "stream_id"
         )
@@ -121,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
-        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -133,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "pushers", "id",
             extra_tables=[("deleted_pushers", "stream_id")],
         )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id",
+        )
 
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = StreamIdGenerator(
@@ -141,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore,
         else:
             self._cache_id_gen = None
 
-        events_max = self._stream_id_gen.get_current_token()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
-            db_conn, "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache", min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max,
-        )
-
-        account_max = self._account_data_id_gen.get_current_token()
-        self._account_data_stream_cache = StreamChangeCache(
-            "AccountDataAndTagsChangeCache", account_max,
-        )
-
         self._presence_on_startup = self._get_active_presence(db_conn)
 
         presence_cache_prefill, min_presence_val = self._get_cache_dict(
@@ -175,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=presence_cache_prefill
         )
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn, "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache", push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
         device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
             db_conn, "device_inbox",
@@ -221,23 +185,35 @@ class DataStore(RoomMemberStore, RoomStore,
             "DeviceListFederationStreamChangeCache", device_list_max,
         )
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-            after_callbacks=[]
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+            db_conn, "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache = StreamChangeCache(
+            "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
         )
-        self._find_stream_orderings_for_times_txn(cur)
-        cur.close()
 
-        self.find_stream_orderings_looping_call = self._clock.looping_call(
-            self._find_stream_orderings_for_times, 60 * 60 * 1000
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn, "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache", min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
         )
 
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        super(DataStore, self).__init__(hs)
+        super(DataStore, self).__init__(db_conn, hs)
 
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
@@ -266,36 +242,110 @@ class DataStore(RoomMemberStore, RoomStore,
 
         return [UserPresenceState(**row) for row in rows]
 
-    @defer.inlineCallbacks
     def count_daily_users(self):
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         def _count_users(txn):
-            txn.execute(
-                "SELECT COUNT(DISTINCT user_id) AS users"
-                " FROM user_ips"
-                " WHERE last_seen > ?",
-                # This is close enough to a day for our purposes.
-                (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
-            )
-            rows = self.cursor_to_dict(txn)
-            if rows:
-                return rows[0]["users"]
-            return 0
+            yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
 
-        ret = yield self.runInteraction("count_users", _count_users)
-        defer.returnValue(ret)
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT user_id FROM user_ips
+                    WHERE last_seen > ?
+                    GROUP BY user_id
+                ) u
+            """
 
-    def get_user_ip_and_agents(self, user):
-        return self._simple_select_list(
-            table="user_ips",
-            keyvalues={"user_id": user.to_string()},
-            retcols=[
-                "access_token", "ip", "user_agent", "last_seen"
-            ],
-            desc="get_user_ip_and_agents",
-        )
+            txn.execute(sql, (yesterday,))
+            count, = txn.fetchone()
+            return count
+
+        return self.runInteraction("count_users", _count_users)
+
+    def count_r30_users(self):
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+         Returns counts globaly for a given user as well as breaking
+         by platform
+        """
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs,
+                              thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] is 'unknown':
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs,
+                              thirty_days_ago_in_secs))
+
+            count, = txn.fetchone()
+            results['all'] = count
+
+            return results
+
+        return self.runInteraction("count_r30_users", _count_r30_users)
 
     def get_users(self):
         """Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b0dc391190..2262776ab2 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -16,9 +16,7 @@ import logging
 
 from synapse.api.errors import StoreError
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
-from synapse.util.caches import intern_dict
 from synapse.storage.engines import PostgresEngine
 import synapse.metrics
 
@@ -28,10 +26,6 @@ from twisted.internet import defer
 import sys
 import time
 import threading
-import os
-
-
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
 logger = logging.getLogger(__name__)
@@ -53,20 +47,27 @@ class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method."""
-    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]
+    __slots__ = [
+        "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
+    ]
 
-    def __init__(self, txn, name, database_engine, after_callbacks):
+    def __init__(self, txn, name, database_engine, after_callbacks,
+                 exception_callbacks):
         object.__setattr__(self, "txn", txn)
         object.__setattr__(self, "name", name)
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
 
-    def call_after(self, callback, *args):
+    def call_after(self, callback, *args, **kwargs):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
         """
-        self.after_callbacks.append((callback, args))
+        self.after_callbacks.append((callback, args, kwargs))
+
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
 
     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -74,13 +75,22 @@ class LoggingTransaction(object):
     def __setattr__(self, name, value):
         setattr(self.txn, name, value)
 
+    def __iter__(self):
+        return self.txn.__iter__()
+
     def execute(self, sql, *args):
         self._do_execute(self.txn.execute, sql, *args)
 
     def executemany(self, sql, *args):
         self._do_execute(self.txn.executemany, sql, *args)
 
+    def _make_sql_one_line(self, sql):
+        "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
     def _do_execute(self, func, sql, *args):
+        sql = self._make_sql_one_line(sql)
+
         # TODO(paul): Maybe use 'info' and 'debug' for values?
         sql_logger.debug("[SQL] {%s} %s", self.name, sql)
 
@@ -91,7 +101,7 @@ class LoggingTransaction(object):
                     "[SQL values] {%s} %r",
                     self.name, args[0]
                 )
-            except:
+            except Exception:
                 # Don't let logging failures stop SQL from working
                 pass
 
@@ -127,7 +137,7 @@ class PerformanceCounters(object):
 
     def interval(self, interval_duration, limit=3):
         counters = []
-        for name, (count, cum_time) in self.current_counters.items():
+        for name, (count, cum_time) in self.current_counters.iteritems():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append((
                 (cum_time - prev_time) / interval_duration,
@@ -150,7 +160,7 @@ class PerformanceCounters(object):
 class SQLBaseStore(object):
     _TXN_ID = 0
 
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self._db_pool = hs.get_db_pool()
@@ -168,10 +178,6 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3,
                                       max_entries=hs.config.event_cache_size)
 
-        self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
-        )
-
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
@@ -209,8 +215,8 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
-    def _new_transaction(self, conn, desc, after_callbacks, logging_context,
-                         func, *args, **kwargs):
+    def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
+                         logging_context, func, *args, **kwargs):
         start = time.time() * 1000
         txn_id = self._TXN_ID
 
@@ -229,7 +235,8 @@ class SQLBaseStore(object):
                 try:
                     txn = conn.cursor()
                     txn = LoggingTransaction(
-                        txn, name, self.database_engine, after_callbacks
+                        txn, name, self.database_engine, after_callbacks,
+                        exception_callbacks,
                     )
                     r = func(txn, *args, **kwargs)
                     conn.commit()
@@ -284,47 +291,66 @@ class SQLBaseStore(object):
 
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
-        current_context = LoggingContext.current_context()
+        """Starts a transaction on the database and runs a given function
 
-        start_time = time.time() * 1000
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (twisted.enterprise.adbapi.Transaction) as
+                its first argument, followed by `args` and `kwargs`.
+
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        current_context = LoggingContext.current_context()
 
         after_callbacks = []
+        exception_callbacks = []
 
         def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runInteraction") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+            return self._new_transaction(
+                conn, desc, after_callbacks, exception_callbacks, current_context,
+                func, *args, **kwargs
+            )
 
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
+        try:
+            result = yield self.runWithConnection(inner_func, *args, **kwargs)
 
-                current_context.copy_to(context)
-                return self._new_transaction(
-                    conn, desc, after_callbacks, current_context,
-                    func, *args, **kwargs
-                )
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
+                after_callback(*after_args, **after_kwargs)
+            raise
 
-        try:
-            with PreserveLoggingContext():
-                result = yield self._db_pool.runWithConnection(
-                    inner_func, *args, **kwargs
-                )
-        finally:
-            for after_callback, after_args in after_callbacks:
-                after_callback(*after_args)
         defer.returnValue(result)
 
     @defer.inlineCallbacks
     def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runInteraction() method on the underlying db_pool."""
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
         current_context = LoggingContext.current_context()
 
         start_time = time.time() * 1000
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runWithConnection") as context:
-                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+                sched_duration_ms = time.time() * 1000 - start_time
+                sql_scheduling_timer.inc_by(sched_duration_ms)
+                current_context.add_database_scheduled(sched_duration_ms)
 
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
@@ -350,9 +376,9 @@ class SQLBaseStore(object):
         Returns:
             A list of dicts where the key is the column header.
         """
-        col_headers = list(column[0] for column in cursor.description)
+        col_headers = list(intern(str(column[0])) for column in cursor.description)
         results = list(
-            intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
+            dict(zip(col_headers, row)) for row in cursor
         )
         return results
 
@@ -417,6 +443,11 @@ class SQLBaseStore(object):
 
         txn.execute(sql, vals)
 
+    def _simple_insert_many(self, table, values, desc):
+        return self.runInteraction(
+            desc, self._simple_insert_many_txn, table, values
+        )
+
     @staticmethod
     def _simple_insert_many_txn(txn, table, values):
         if not values:
@@ -452,23 +483,53 @@ class SQLBaseStore(object):
 
         txn.executemany(sql, vals)
 
+    @defer.inlineCallbacks
     def _simple_upsert(self, table, keyvalues, values,
                        insertion_values={}, desc="_simple_upsert", lock=True):
         """
+
+        `lock` should generally be set to True (the default), but can be set
+        to False if either of the following are true:
+
+        * there is a UNIQUE INDEX on the key columns. In this case a conflict
+          will cause an IntegrityError in which case this function will retry
+          the update.
+
+        * we somehow know that we are the only thread which will be updating
+          this table.
+
         Args:
             table (str): The table to upsert into
             keyvalues (dict): The unique key tables and their new values
             values (dict): The nonunique columns and their new values
-            insertion_values (dict): key/values to use when inserting
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
         Returns:
             Deferred(bool): True if a new entry was created, False if an
                 existing one was updated.
         """
-        return self.runInteraction(
-            desc,
-            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
-            lock
-        )
+        attempts = 0
+        while True:
+            try:
+                result = yield self.runInteraction(
+                    desc,
+                    self._simple_upsert_txn, table, keyvalues, values, insertion_values,
+                    lock=lock
+                )
+                defer.returnValue(result)
+            except self.database_engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
+                # presumably we raced with another transaction: let's retry.
+                logger.warn(
+                    "IntegrityError when upserting into %s; retrying: %s",
+                    table, e
+                )
 
     def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
                            lock=True):
@@ -476,45 +537,38 @@ class SQLBaseStore(object):
         if lock:
             self.database_engine.lock_table(txn, table)
 
-        # Try to update
+        # First try to update.
         sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in values),
             " AND ".join("%s = ?" % (k,) for k in keyvalues)
         )
         sqlargs = values.values() + keyvalues.values()
-        logger.debug(
-            "[SQL] %s Args=%s",
-            sql, sqlargs,
-        )
 
         txn.execute(sql, sqlargs)
-        if txn.rowcount == 0:
-            # We didn't update and rows so insert a new one
-            allvalues = {}
-            allvalues.update(keyvalues)
-            allvalues.update(values)
-            allvalues.update(insertion_values)
+        if txn.rowcount > 0:
+            # successfully updated at least one row.
+            return False
 
-            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
-                table,
-                ", ".join(k for k in allvalues),
-                ", ".join("?" for _ in allvalues)
-            )
-            logger.debug(
-                "[SQL] %s Args=%s",
-                sql, keyvalues.values(),
-            )
-            txn.execute(sql, allvalues.values())
+        # We didn't update any rows so insert a new one
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
 
-            return True
-        else:
-            return False
+        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues)
+        )
+        txn.execute(sql, allvalues.values())
+        # successfully inserted
+        return True
 
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
+        return a single row, returning multiple columns from it.
 
         Args:
             table : string giving the table name
@@ -567,22 +621,20 @@ class SQLBaseStore(object):
 
     @staticmethod
     def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
-        else:
-            where = ""
-
         sql = (
-            "SELECT %(retcol)s FROM %(table)s %(where)s"
+            "SELECT %(retcol)s FROM %(table)s"
         ) % {
             "retcol": retcol,
             "table": table,
-            "where": where,
         }
 
-        txn.execute(sql, keyvalues.values())
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+            txn.execute(sql, keyvalues.values())
+        else:
+            txn.execute(sql)
 
-        return [r[0] for r in txn.fetchall()]
+        return [r[0] for r in txn]
 
     def _simple_select_onecol(self, table, keyvalues, retcol,
                               desc="_simple_select_onecol"):
@@ -591,7 +643,7 @@ class SQLBaseStore(object):
 
         Args:
             table (str): table name
-            keyvalues (dict): column names and values to select the rows with
+            keyvalues (dict|None): column names and values to select the rows with
             retcol (str): column whos value we wish to retrieve.
 
         Returns:
@@ -715,7 +767,7 @@ class SQLBaseStore(object):
         )
         values.extend(iterable)
 
-        for key, value in keyvalues.items():
+        for key, value in keyvalues.iteritems():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -728,6 +780,33 @@ class SQLBaseStore(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
+    def _simple_update(self, table, keyvalues, updatevalues, desc):
+        return self.runInteraction(
+            desc,
+            self._simple_update_txn,
+            table, keyvalues, updatevalues,
+        )
+
+    @staticmethod
+    def _simple_update_txn(txn, table, keyvalues, updatevalues):
+        if keyvalues:
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+        else:
+            where = ""
+
+        update_sql = "UPDATE %s SET %s %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            where,
+        )
+
+        txn.execute(
+            update_sql,
+            updatevalues.values() + keyvalues.values()
+        )
+
+        return txn.rowcount
+
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
@@ -753,27 +832,13 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    @staticmethod
-    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
-        else:
-            where = ""
-
-        update_sql = "UPDATE %s SET %s %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in updatevalues),
-            where,
-        )
-
-        txn.execute(
-            update_sql,
-            updatevalues.values() + keyvalues.values()
-        )
+    @classmethod
+    def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+        rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
 
-        if txn.rowcount == 0:
+        if rowcount == 0:
             raise StoreError(404, "No row found")
-        if txn.rowcount > 1:
+        if rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
     @staticmethod
@@ -843,6 +908,47 @@ class SQLBaseStore(object):
 
         return txn.execute(sql, keyvalues.values())
 
+    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+        """
+        if not iterable:
+            return
+
+        sql = "DELETE FROM %s" % table
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.iteritems():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+        return txn.execute(sql, values)
+
     def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
                         max_value, limit=100000):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -863,16 +969,16 @@ class SQLBaseStore(object):
 
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))
-        rows = txn.fetchall()
-        txn.close()
 
         cache = {
             row[0]: int(row[1])
-            for row in rows
+            for row in txn
         }
 
+        txn.close()
+
         if cache:
-            min_val = min(cache.values())
+            min_val = min(cache.itervalues())
         else:
             min_val = max_value
 
@@ -895,6 +1001,7 @@ class SQLBaseStore(object):
             # __exit__ called after the transaction finishes.
             ctx = self._cache_id_gen.get_next()
             stream_id = ctx.__enter__()
+            txn.call_on_exception(ctx.__exit__, None, None, None)
             txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 3fa226e92d..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,18 +14,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.util.id_generators import StreamIdGenerator
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
 
-import ujson as json
+import abc
+import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class AccountDataStore(SQLBaseStore):
+class AccountDataWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_account_data_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        account_max = self.get_max_account_data_stream_id()
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache", account_max,
+        )
+
+        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+
+    @abc.abstractmethod
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream ID for account data stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
 
     @cached()
     def get_account_data_for_user(self, user_id):
@@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
-    @cachedInlineCallbacks(num_args=2)
+    @cachedInlineCallbacks(num_args=2, max_entries=5000)
     def get_global_account_data_by_type_for_user(self, data_type, user_id):
         """
         Returns:
@@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore):
             for row in rows
         })
 
+    @cached(num_args=2)
     def get_account_data_for_room(self, user_id, room_id):
         """Get all the client account_data for a user for a room.
 
@@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
+    @cached(num_args=3, max_entries=5000)
+    def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
+        """Get the client account_data of given type for a user for a room.
+
+        Args:
+            user_id(str): The user to get the account_data for.
+            room_id(str): The room to get the account_data for.
+            account_data_type (str): The account data type to get.
+        Returns:
+            A deferred of the room account_data for that type, or None if
+            there isn't any set.
+        """
+        def get_account_data_for_room_and_type_txn(txn):
+            content_json = self._simple_select_one_onecol_txn(
+                txn,
+                table="room_account_data",
+                keyvalues={
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "account_data_type": account_data_type,
+                },
+                retcol="content",
+                allow_none=True
+            )
+
+            return json.loads(content_json) if content_json else None
+
+        return self.runInteraction(
+            "get_account_data_for_room_and_type",
+            get_account_data_for_room_and_type_txn,
+        )
+
     def get_all_updated_account_data(self, last_global_id, last_room_id,
                                      current_id, limit):
         """Get all the client account_data that has changed on the server
@@ -182,7 +244,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             global_account_data = {
-                row[0]: json.loads(row[1]) for row in txn.fetchall()
+                row[0]: json.loads(row[1]) for row in txn
             }
 
             sql = (
@@ -193,7 +255,7 @@ class AccountDataStore(SQLBaseStore):
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room = {}
-            for row in txn.fetchall():
+            for row in txn:
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = json.loads(row[2])
 
@@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore):
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
+    @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
+    def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
+        ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+            "m.ignored_user_list", ignorer_user_id,
+            on_invalidate=cache_context.invalidate,
+        )
+        if not ignored_account_data:
+            defer.returnValue(False)
+
+        defer.returnValue(
+            ignored_user_id in ignored_account_data.get("ignored_users", {})
+        )
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    def __init__(self, db_conn, hs):
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        super(AccountDataStore, self).__init__(db_conn, hs)
+
+    def get_max_account_data_stream_id(self):
+        """Get the current max stream id for the private user data stream
+
+        Returns:
+            A deferred int.
+        """
+        return self._account_data_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
         """Add some account_data to a room for a user.
@@ -222,9 +314,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as room_account_data has a unique constraint
+            # on (user_id, room_id, account_data_type) so _simple_upsert will
+            # retry if there is a conflict.
+            yield self._simple_upsert(
+                desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -234,18 +329,23 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
-            )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
-                user_id, next_id,
+                },
+                lock=False,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            self._update_max_stream_id(txn, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_room_account_data", add_account_data_txn, next_id
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_account_data_for_room.invalidate((user_id, room_id,))
+            self.get_account_data_for_room_and_type.prefill(
+                (user_id, room_id, account_data_type,), content,
             )
 
         result = self._account_data_id_gen.get_current_token()
@@ -263,9 +363,12 @@ class AccountDataStore(SQLBaseStore):
         """
         content_json = json.dumps(content)
 
-        def add_account_data_txn(txn, next_id):
-            self._simple_upsert_txn(
-                txn,
+        with self._account_data_id_gen.get_next() as next_id:
+            # no need to lock here as account_data has a unique constraint on
+            # (user_id, account_data_type) so _simple_upsert will retry if
+            # there is a conflict.
+            yield self._simple_upsert(
+                desc="add_user_account_data",
                 table="account_data",
                 keyvalues={
                     "user_id": user_id,
@@ -274,37 +377,43 @@ class AccountDataStore(SQLBaseStore):
                 values={
                     "stream_id": next_id,
                     "content": content_json,
-                }
+                },
+                lock=False,
             )
-            txn.call_after(
-                self._account_data_stream_cache.entity_has_changed,
+
+            # it's theoretically possible for the above to succeed and the
+            # below to fail - in which case we might reuse a stream id on
+            # restart, and the above update might not get propagated. That
+            # doesn't sound any worse than the whole update getting lost,
+            # which is what would happen if we combined the two into one
+            # transaction.
+            yield self._update_max_stream_id(next_id)
+
+            self._account_data_stream_cache.entity_has_changed(
                 user_id, next_id,
             )
-            txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
-            txn.call_after(
-                self.get_global_account_data_by_type_for_user.invalidate,
+            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_by_type_for_user.invalidate(
                 (account_data_type, user_id,)
             )
-            self._update_max_stream_id(txn, next_id)
-
-        with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction(
-                "add_user_account_data", add_account_data_txn, next_id
-            )
 
         result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
-    def _update_max_stream_id(self, txn, next_id):
+    def _update_max_stream_id(self, next_id):
         """Update the max stream_id
 
         Args:
-            txn: The database cursor
             next_id(int): The the revision to advance to.
         """
-        update_max_id_sql = (
-            "UPDATE account_data_max_stream_id"
-            " SET stream_id = ?"
-            " WHERE stream_id < ?"
+        def _update(txn):
+            update_max_id_sql = (
+                "UPDATE account_data_max_stream_id"
+                " SET stream_id = ?"
+                " WHERE stream_id < ?"
+            )
+            txn.execute(update_max_id_sql, (next_id, next_id))
+        return self.runInteraction(
+            "update_account_data_max_stream_id",
+            _update,
         )
-        txn.execute(update_max_id_sql, (next_id, next_id))
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..12ea8a158c 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,39 +14,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
 import simplejson as json
 from twisted.internet import defer
 
-from synapse.api.constants import Membership
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.events import EventsWorkerStore
 from ._base import SQLBaseStore
 
 
 logger = logging.getLogger(__name__)
 
 
-class ApplicationServiceStore(SQLBaseStore):
+def _make_exclusive_regex(services_cache):
+    # We precompie a regex constructed from all the regexes that the AS's
+    # have registered for exclusive users.
+    exclusive_user_regexes = [
+        regex.pattern
+        for service in services_cache
+        for regex in service.get_exlusive_user_regexes()
+    ]
+    if exclusive_user_regexes:
+        exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
+        exclusive_user_regex = re.compile(exclusive_user_regex)
+    else:
+        # We handle this case specially otherwise the constructed regex
+        # will always match
+        exclusive_user_regex = None
 
-    def __init__(self, hs):
-        super(ApplicationServiceStore, self).__init__(hs)
-        self.hostname = hs.hostname
+    return exclusive_user_regex
+
+
+class ApplicationServiceWorkerStore(SQLBaseStore):
+    def __init__(self, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname,
             hs.config.app_service_config_files
         )
+        self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+
+        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
 
     def get_if_app_services_interested_in_user(self, user_id):
-        """Check if the user is one associated with an app service
+        """Check if the user is one associated with an app service (exclusively)
         """
-        for service in self.services_cache:
-            if service.is_interested_in_user(user_id):
-                return True
-        return False
+        if self.exclusive_user_regex:
+            return bool(self.exclusive_user_regex.match(user_id))
+        else:
+            return False
 
     def get_app_service_by_user_id(self, user_id):
         """Retrieve an application service from their user ID.
@@ -78,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore):
                 return service
         return None
 
-    def get_app_service_rooms(self, service):
-        """Get a list of RoomsForUser for this application service.
-
-        Application services may be "interested" in lots of rooms depending on
-        the room ID, the room aliases, or the members in the room. This function
-        takes all of these into account and returns a list of RoomsForUser which
-        represent the entire list of room IDs that this application service
-        wants to know about.
+    def get_app_service_by_id(self, as_id):
+        """Get the application service with the given appservice ID.
 
         Args:
-            service: The application service to get a room list for.
+            as_id (str): The application service ID.
         Returns:
-            A list of RoomsForUser.
+            synapse.appservice.ApplicationService or None.
         """
-        return self.runInteraction(
-            "get_app_service_rooms",
-            self._get_app_service_rooms_txn,
-            service,
-        )
-
-    def _get_app_service_rooms_txn(self, txn, service):
-        # get all rooms matching the room ID regex.
-        room_entries = self._simple_select_list_txn(
-            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
-        )
-        matching_room_list = set([
-            r["room_id"] for r in room_entries if
-            service.is_interested_in_room(r["room_id"])
-        ])
-
-        # resolve room IDs for matching room alias regex.
-        room_alias_mappings = self._simple_select_list_txn(
-            txn=txn, table="room_aliases", keyvalues=None,
-            retcols=["room_id", "room_alias"]
-        )
-        matching_room_list |= set([
-            r["room_id"] for r in room_alias_mappings if
-            service.is_interested_in_alias(r["room_alias"])
-        ])
-
-        # get all rooms for every user for this AS. This is scoped to users on
-        # this HS only.
-        user_list = self._simple_select_list_txn(
-            txn=txn, table="users", keyvalues=None, retcols=["name"]
-        )
-        user_list = [
-            u["name"] for u in user_list if
-            service.is_interested_in_user(u["name"])
-        ]
-        rooms_for_user_matching_user_id = set()  # RoomsForUser list
-        for user_id in user_list:
-            # FIXME: This assumes this store is linked with RoomMemberStore :(
-            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
-                txn=txn,
-                user_id=user_id,
-                membership_list=[Membership.JOIN]
-            )
-            rooms_for_user_matching_user_id |= set(rooms_for_user)
-
-        # make RoomsForUser tuples for room ids and aliases which are not in the
-        # main rooms_for_user_list - e.g. they are rooms which do not have AS
-        # registered users in it.
-        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
-        missing_rooms_for_user = [
-            RoomsForUser(r, service.sender, "join") for r in
-            matching_room_list if r not in known_room_ids
-        ]
-        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
-
-        return rooms_for_user_matching_user_id
+        for service in self.services_cache:
+            if service.id == as_id:
+                return service
+        return None
 
 
-class ApplicationServiceTransactionStore(SQLBaseStore):
+class ApplicationServiceStore(ApplicationServiceWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
 
-    def __init__(self, hs):
-        super(ApplicationServiceTransactionStore, self).__init__(hs)
 
+class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
+                                               EventsWorkerStore):
     @defer.inlineCallbacks
     def get_appservices_by_state(self, state):
         """Get a list of application services based on their state.
@@ -399,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
         events = yield self._get_events(event_ids)
 
         defer.returnValue((upper_bound, events))
+
+
+class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
+    # This is currently empty due to there not being any AS storage functions
+    # that can't be run on the workers. Since this may change in future, and
+    # to keep consistency with the other stores, we keep this empty class for
+    # now.
+    pass
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 94b2bcc54a..8af325a9f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,13 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import synapse.util.async
 
 from ._base import SQLBaseStore
 from . import engines
 
 from twisted.internet import defer
 
-import ujson as json
+import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -79,35 +80,26 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, hs):
-        super(BackgroundUpdateStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
-        self._background_update_timer = None
+        self._all_done = False
 
     @defer.inlineCallbacks
     def start_doing_background_updates(self):
-        assert self._background_update_timer is None, \
-            "background updates already running"
-
         logger.info("Starting background schema updates")
 
         while True:
-            sleep = defer.Deferred()
-            self._background_update_timer = self._clock.call_later(
-                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
-            )
-            try:
-                yield sleep
-            finally:
-                self._background_update_timer = None
+            yield synapse.util.async.sleep(
+                self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
 
             try:
                 result = yield self.do_next_background_update(
                     self.BACKGROUND_UPDATE_DURATION_MS
                 )
-            except:
+            except Exception:
                 logger.exception("Error doing update")
             else:
                 if result is None:
@@ -115,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore):
                         "No more background updates to do."
                         " Unscheduling background update task."
                     )
+                    self._all_done = True
                     defer.returnValue(None)
 
     @defer.inlineCallbacks
+    def has_completed_background_updates(self):
+        """Check if all the background updates have completed
+
+        Returns:
+            Deferred[bool]: True if all background updates have completed
+        """
+        # if we've previously determined that there is nothing left to do, that
+        # is easy
+        if self._all_done:
+            defer.returnValue(True)
+
+        # obviously, if we have things in our queue, we're not done.
+        if self._background_update_queue:
+            defer.returnValue(False)
+
+        # otherwise, check if there are updates to be run. This is important,
+        # as we may be running on a worker which doesn't perform the bg updates
+        # itself, but still wants to wait for them to happen.
+        updates = yield self._simple_select_onecol(
+            "background_updates",
+            keyvalues=None,
+            retcol="1",
+            desc="check_background_updates",
+        )
+        if not updates:
+            self._all_done = True
+            defer.returnValue(True)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
     def do_next_background_update(self, desired_duration_ms):
         """Does some amount of work on the next queued background update
 
@@ -218,8 +242,29 @@ class BackgroundUpdateStore(SQLBaseStore):
         """
         self._background_update_handlers[update_name] = update_handler
 
+    def register_noop_background_update(self, update_name):
+        """Register a noop handler for a background update.
+
+        This is useful when we previously did a background update, but no
+        longer wish to do the update. In this case the background update should
+        be removed from the schema delta files, but there may still be some
+        users who have the background update queued, so this method should
+        also be called to clear the update.
+
+        Args:
+            update_name (str): Name of update
+        """
+        @defer.inlineCallbacks
+        def noop_update(progress, batch_size):
+            yield self._end_background_update(update_name)
+            defer.returnValue(1)
+
+        self.register_background_update_handler(update_name, noop_update)
+
     def register_background_index_update(self, update_name, index_name,
-                                         table, columns, where_clause=None):
+                                         table, columns, where_clause=None,
+                                         unique=False,
+                                         psql_only=False):
         """Helper for store classes to do a background index addition
 
         To use:
@@ -235,48 +280,80 @@ class BackgroundUpdateStore(SQLBaseStore):
             index_name (str): name of index to add
             table (str): table to add index to
             columns (list[str]): columns/expressions to include in index
+            unique (bool): true to make a UNIQUE index
+            psql_only: true to only create this index on psql databases (useful
+                for virtual sqlite tables)
         """
 
-        # if this is postgres, we add the indexes concurrently. Otherwise
-        # we fall back to doing it inline
-        if isinstance(self.database_engine, engines.PostgresEngine):
-            conc = True
-        else:
-            conc = False
-            # We don't use partial indices on SQLite as it wasn't introduced
-            # until 3.8, and wheezy has 3.7
-            where_clause = None
-
-        sql = (
-            "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)"
-            " %(where_clause)s"
-        ) % {
-            "conc": "CONCURRENTLY" if conc else "",
-            "name": index_name,
-            "table": table,
-            "columns": ", ".join(columns),
-            "where_clause": "WHERE " + where_clause if where_clause else ""
-        }
-
-        def create_index_concurrently(conn):
+        def create_index_psql(conn):
             conn.rollback()
             # postgres insists on autocommit for the index
             conn.set_session(autocommit=True)
-            c = conn.cursor()
-            c.execute(sql)
-            conn.set_session(autocommit=False)
 
-        def create_index(conn):
+            try:
+                c = conn.cursor()
+
+                # If a previous attempt to create the index was interrupted,
+                # we may already have a half-built index. Let's just drop it
+                # before trying to create it again.
+
+                sql = "DROP INDEX IF EXISTS %s" % (index_name,)
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+
+                sql = (
+                    "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
+                    " ON %(table)s"
+                    " (%(columns)s) %(where_clause)s"
+                ) % {
+                    "unique": "UNIQUE" if unique else "",
+                    "name": index_name,
+                    "table": table,
+                    "columns": ", ".join(columns),
+                    "where_clause": "WHERE " + where_clause if where_clause else ""
+                }
+                logger.debug("[SQL] %s", sql)
+                c.execute(sql)
+            finally:
+                conn.set_session(autocommit=False)
+
+        def create_index_sqlite(conn):
+            # Sqlite doesn't support concurrent creation of indexes.
+            #
+            # We don't use partial indices on SQLite as it wasn't introduced
+            # until 3.8, and wheezy and CentOS 7 have 3.7
+            #
+            # We assume that sqlite doesn't give us invalid indices; however
+            # we may still end up with the index existing but the
+            # background_updates not having been recorded if synapse got shut
+            # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite
+            # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.)
+            sql = (
+                "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s"
+                " (%(columns)s)"
+            ) % {
+                "unique": "UNIQUE" if unique else "",
+                "name": index_name,
+                "table": table,
+                "columns": ", ".join(columns),
+            }
+
             c = conn.cursor()
+            logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
+        if isinstance(self.database_engine, engines.PostgresEngine):
+            runner = create_index_psql
+        elif psql_only:
+            runner = None
+        else:
+            runner = create_index_sqlite
+
         @defer.inlineCallbacks
         def updater(progress, batch_size):
-            logger.info("Adding index %s to %s", index_name, table)
-            if conc:
-                yield self.runWithConnection(create_index_concurrently)
-            else:
-                yield self.runWithConnection(create_index)
+            if runner is not None:
+                logger.info("Adding index %s to %s", index_name, table)
+                yield self.runWithConnection(runner)
             yield self._end_background_update(update_name)
             defer.returnValue(1)
 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 71e5ea112f..7b44dae0fc 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -15,11 +15,14 @@
 
 import logging
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
 from ._base import Cache
 from . import background_updates
 
+from synapse.util.caches import CACHE_SIZE_FACTOR
+
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,13 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 
 class ClientIpStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, hs):
+    def __init__(self, db_conn, hs):
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
+            max_entries=50000 * CACHE_SIZE_FACTOR,
         )
 
-        super(ClientIpStore, self).__init__(hs)
+        super(ClientIpStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "user_ips_device_index",
@@ -44,10 +48,26 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
             columns=["user_id", "device_id", "last_seen"],
         )
 
-    @defer.inlineCallbacks
-    def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
-        now = int(self._clock.time_msec())
-        key = (user.to_string(), access_token, ip)
+        self.register_background_index_update(
+            "user_ips_last_seen_index",
+            index_name="user_ips_last_seen",
+            table="user_ips",
+            columns=["user_id", "last_seen"],
+        )
+
+        # (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
+        self._batch_row_update = {}
+
+        self._client_ip_looper = self._clock.looping_call(
+            self._update_client_ips_batch, 5 * 1000
+        )
+        reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch)
+
+    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
+                         now=None):
+        if not now:
+            now = int(self._clock.time_msec())
+        key = (user_id, access_token, ip)
 
         try:
             last_seen = self.client_ip_last_seen.get(key)
@@ -56,34 +76,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         # Rate-limited inserts
         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
-            defer.returnValue(None)
+            return
 
         self.client_ip_last_seen.prefill(key, now)
 
-        # It's safe not to lock here: a) no unique constraint,
-        # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
-        yield self._simple_upsert(
-            "user_ips",
-            keyvalues={
-                "user_id": user.to_string(),
-                "access_token": access_token,
-                "ip": ip,
-                "user_agent": user_agent,
-                "device_id": device_id,
-            },
-            values={
-                "last_seen": now,
-            },
-            desc="insert_client_ip",
-            lock=False,
+        self._batch_row_update[key] = (user_agent, device_id, now)
+
+    def _update_client_ips_batch(self):
+        to_update = self._batch_row_update
+        self._batch_row_update = {}
+        return self.runInteraction(
+            "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
+    def _update_client_ips_batch_txn(self, txn, to_update):
+        self.database_engine.lock_table(txn, "user_ips")
+
+        for entry in to_update.iteritems():
+            (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
+
+            self._simple_upsert_txn(
+                txn,
+                table="user_ips",
+                keyvalues={
+                    "user_id": user_id,
+                    "access_token": access_token,
+                    "ip": ip,
+                    "user_agent": user_agent,
+                    "device_id": device_id,
+                },
+                values={
+                    "last_seen": last_seen,
+                },
+                lock=False,
+            )
+
     @defer.inlineCallbacks
-    def get_last_client_ip_by_device(self, devices):
+    def get_last_client_ip_by_device(self, user_id, device_id):
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
-            devices (iterable[(str, str)]):  list of (user_id, device_id) pairs
+            user_id (str)
+            device_id (str): If None fetches all devices for the user
 
         Returns:
             defer.Deferred: resolves to a dict, where the keys
@@ -94,6 +128,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         res = yield self.runInteraction(
             "get_last_client_ip_by_device",
             self._get_last_client_ip_by_device_txn,
+            user_id, device_id,
             retcols=(
                 "user_id",
                 "access_token",
@@ -102,23 +137,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
                 "device_id",
                 "last_seen",
             ),
-            devices=devices
         )
 
         ret = {(d["user_id"], d["device_id"]): d for d in res}
+        for key in self._batch_row_update:
+            uid, access_token, ip = key
+            if uid == user_id:
+                user_agent, did, last_seen = self._batch_row_update[key]
+                if not device_id or did == device_id:
+                    ret[(user_id, device_id)] = {
+                        "user_id": user_id,
+                        "access_token": access_token,
+                        "ip": ip,
+                        "user_agent": user_agent,
+                        "device_id": did,
+                        "last_seen": last_seen,
+                    }
         defer.returnValue(ret)
 
     @classmethod
-    def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
+    def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
         where_clauses = []
         bindings = []
-        for (user_id, device_id) in devices:
-            if device_id is None:
-                where_clauses.append("(user_id = ? AND device_id IS NULL)")
-                bindings.extend((user_id, ))
-            else:
-                where_clauses.append("(user_id = ? AND device_id = ?)")
-                bindings.extend((user_id, device_id))
+        if device_id is None:
+            where_clauses.append("user_id = ?")
+            bindings.extend((user_id, ))
+        else:
+            where_clauses.append("(user_id = ? AND device_id = ?)")
+            bindings.extend((user_id, device_id))
+
+        if not where_clauses:
+            return []
 
         inner_select = (
             "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
@@ -143,3 +192,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
         txn.execute(sql, bindings)
         return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def get_user_ip_and_agents(self, user):
+        user_id = user.to_string()
+        results = {}
+
+        for key in self._batch_row_update:
+            uid, access_token, ip = key
+            if uid == user_id:
+                user_agent, _, last_seen = self._batch_row_update[key]
+                results[(access_token, ip)] = (user_agent, last_seen)
+
+        rows = yield self._simple_select_list(
+            table="user_ips",
+            keyvalues={"user_id": user_id},
+            retcols=[
+                "access_token", "ip", "user_agent", "last_seen"
+            ],
+            desc="get_user_ip_and_agents",
+        )
+
+        results.update(
+            ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
+            for row in rows
+        )
+        defer.returnValue(list(
+            {
+                "access_token": access_token,
+                "ip": ip,
+                "user_agent": user_agent,
+                "last_seen": last_seen,
+            }
+            for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+        ))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index bde3b5cbbc..a879e5bfc1 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 
 import logging
-import ujson
+import simplejson
 
 from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 
+from synapse.util.caches.expiringcache import ExpiringCache
+
 
 logger = logging.getLogger(__name__)
 
@@ -27,8 +29,8 @@ logger = logging.getLogger(__name__)
 class DeviceInboxStore(BackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, hs):
-        super(DeviceInboxStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(db_conn, hs)
 
         self.register_background_index_update(
             "device_inbox_stream_index",
@@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
             self._background_drop_index_device_inbox,
         )
 
+        # Map of (user_id, device_id) to the last stream_id that has been
+        # deleted up to. This is so that we can no op deletions.
+        self._last_device_delete_cache = ExpiringCache(
+            cache_name="last_device_delete_cache",
+            clock=self._clock,
+            max_len=10000,
+            expiry_ms=30 * 60 * 1000,
+        )
+
     @defer.inlineCallbacks
     def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
                                      remote_messages_by_destination):
@@ -74,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             )
             rows = []
             for destination, edu in remote_messages_by_destination.items():
-                edu_json = ujson.dumps(edu)
+                edu_json = simplejson.dumps(edu)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
@@ -166,8 +177,8 @@ class DeviceInboxStore(BackgroundUpdateStore):
                     " WHERE user_id = ?"
                 )
                 txn.execute(sql, (user_id,))
-                message_json = ujson.dumps(messages_by_device["*"])
-                for row in txn.fetchall():
+                message_json = simplejson.dumps(messages_by_device["*"])
+                for row in txn:
                     # Add the message for all devices for this user on this
                     # server.
                     device = row[0]
@@ -184,11 +195,11 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 # TODO: Maybe this needs to be done in batches if there are
                 # too many local devices for a given user.
                 txn.execute(sql, [user_id] + devices)
-                for row in txn.fetchall():
+                for row in txn:
                     # Only insert into the local inbox if the device exists on
                     # this server
                     device = row[0]
-                    message_json = ujson.dumps(messages_by_device[device])
+                    message_json = simplejson.dumps(messages_by_device[device])
                     messages_json_for_user[device] = message_json
 
             if messages_json_for_user:
@@ -240,9 +251,9 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 user_id, device_id, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(simplejson.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
@@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
             "get_new_messages_for_device", get_new_messages_for_device_txn,
         )
 
+    @defer.inlineCallbacks
     def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
         """
         Args:
@@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
         Returns:
             A deferred that resolves to the number of messages deleted.
         """
+        # If we have cached the last stream id we've deleted up to, we can
+        # check if there is likely to be anything that needs deleting
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), None
+        )
+        if last_deleted_stream_id:
+            has_changed = self._device_inbox_stream_cache.has_entity_changed(
+                user_id, last_deleted_stream_id
+            )
+            if not has_changed:
+                defer.returnValue(0)
+
         def delete_messages_for_device_txn(txn):
             sql = (
                 "DELETE FROM device_inbox"
@@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        return self.runInteraction(
+        count = yield self.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
+        # Update the cache, ensuring that we only ever increase the value
+        last_deleted_stream_id = self._last_device_delete_cache.get(
+            (user_id, device_id), 0
+        )
+        self._last_device_delete_cache[(user_id, device_id)] = max(
+            last_deleted_stream_id, up_to_stream_id
+        )
+
+        defer.returnValue(count)
+
     def get_all_new_device_messages(self, last_pos, current_pos, limit):
         """
         Args:
@@ -291,22 +325,25 @@ class DeviceInboxStore(BackgroundUpdateStore):
             # we return.
             upper_pos = min(current_pos, last_pos + limit)
             sql = (
-                "SELECT stream_id, user_id"
+                "SELECT max(stream_id), user_id"
                 " FROM device_inbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
+                " GROUP BY user_id"
             )
             txn.execute(sql, (last_pos, upper_pos))
             rows = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, destination"
+                "SELECT max(stream_id), destination"
                 " FROM device_federation_outbox"
                 " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
+                " GROUP BY destination"
             )
             txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn.fetchall())
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
 
             return rows
 
@@ -323,12 +360,12 @@ class DeviceInboxStore(BackgroundUpdateStore):
         """
         Args:
             destination(str): The name of the remote server.
-            last_stream_id(int): The last position of the device message stream
+            last_stream_id(int|long): The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int): The current position of the device
+            current_stream_id(int|long): The current position of the device
                 message stream.
         Returns:
-            Deferred ([dict], int): List of messages for the device and where
+            Deferred ([dict], int|long): List of messages for the device and where
                 in the stream the messages got to.
         """
 
@@ -350,9 +387,9 @@ class DeviceInboxStore(BackgroundUpdateStore):
                 destination, last_stream_id, current_stream_id, limit
             ))
             messages = []
-            for row in txn.fetchall():
+            for row in txn:
                 stream_pos = row[0]
-                messages.append(ujson.loads(row[1]))
+                messages.append(simplejson.loads(row[1]))
             if len(messages) < limit:
                 stream_pos = current_stream_id
             return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 8e17800364..712106b83a 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,24 +13,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import ujson as json
+import simplejson as json
 
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, Cache
+from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+
 
 logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(DeviceStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(DeviceStore, self).__init__(db_conn, hs)
+
+        # Map of (user_id, device_id) -> bool. If there is an entry that implies
+        # the device exists.
+        self.device_id_exists_cache = Cache(
+            name="device_id_exists",
+            keylen=2,
+            max_entries=10000,
+        )
 
         self._clock.looping_call(
             self._prune_old_outbound_device_pokes, 60 * 60 * 1000
         )
 
+        self.register_background_index_update(
+            "device_lists_stream_idx",
+            index_name="device_lists_stream_user_id",
+            table="device_lists_stream",
+            columns=["user_id", "device_id"],
+        )
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id,
                      initial_device_display_name):
@@ -45,6 +62,10 @@ class DeviceStore(SQLBaseStore):
             defer.Deferred: boolean whether the device was inserted or an
                 existing device existed with that ID.
         """
+        key = (user_id, device_id)
+        if self.device_id_exists_cache.get(key, None):
+            defer.returnValue(False)
+
         try:
             inserted = yield self._simple_insert(
                 "devices",
@@ -56,6 +77,7 @@ class DeviceStore(SQLBaseStore):
                 desc="store_device",
                 or_ignore=True,
             )
+            self.device_id_exists_cache.prefill(key, True)
             defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -84,6 +106,7 @@ class DeviceStore(SQLBaseStore):
             desc="get_device",
         )
 
+    @defer.inlineCallbacks
     def delete_device(self, user_id, device_id):
         """Delete a device.
 
@@ -93,12 +116,34 @@ class DeviceStore(SQLBaseStore):
         Returns:
             defer.Deferred
         """
-        return self._simple_delete_one(
+        yield self._simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id},
             desc="delete_device",
         )
 
+        self.device_id_exists_cache.invalidate((user_id, device_id))
+
+    @defer.inlineCallbacks
+    def delete_devices(self, user_id, device_ids):
+        """Deletes several devices.
+
+        Args:
+            user_id (str): The ID of the user which owns the devices
+            device_ids (list): The IDs of the devices to delete
+        Returns:
+            defer.Deferred
+        """
+        yield self._simple_delete_many(
+            table="devices",
+            column="device_id",
+            iterable=device_ids,
+            keyvalues={"user_id": user_id},
+            desc="delete_devices",
+        )
+        for device_id in device_ids:
+            self.device_id_exists_cache.invalidate((user_id, device_id))
+
     def update_device(self, user_id, device_id, new_display_name=None):
         """Update a device.
 
@@ -144,6 +189,7 @@ class DeviceStore(SQLBaseStore):
 
         defer.returnValue({d["device_id"]: d for d in devices})
 
+    @cached(max_entries=10000)
     def get_device_list_last_stream_id_for_remote(self, user_id):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
@@ -156,16 +202,36 @@ class DeviceStore(SQLBaseStore):
             allow_none=True,
         )
 
+    @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
+                list_name="user_ids", inlineCallbacks=True)
+    def get_device_list_last_stream_id_for_remotes(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="device_lists_remote_extremeties",
+            column="user_id",
+            iterable=user_ids,
+            retcols=("user_id", "stream_id",),
+            desc="get_user_devices_from_cache",
+        )
+
+        results = {user_id: None for user_id in user_ids}
+        results.update({
+            row["user_id"]: row["stream_id"] for row in rows
+        })
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
     def mark_remote_user_device_list_as_unsubscribed(self, user_id):
         """Mark that we no longer track device lists for remote user.
         """
-        return self._simple_delete(
+        yield self._simple_delete(
             table="device_lists_remote_extremeties",
             keyvalues={
                 "user_id": user_id,
             },
             desc="mark_remote_user_device_list_as_unsubscribed",
         )
+        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
 
     def update_remote_device_list_cache_entry(self, user_id, device_id, content,
                                               stream_id):
@@ -191,6 +257,12 @@ class DeviceStore(SQLBaseStore):
             }
         )
 
+        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
@@ -234,6 +306,12 @@ class DeviceStore(SQLBaseStore):
             ]
         )
 
+        txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+        txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
+        txn.call_after(
+            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
+        )
+
         self._simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
@@ -249,7 +327,7 @@ class DeviceStore(SQLBaseStore):
         """Get stream of updates to send to remote servers
 
         Returns:
-            (now_stream_id, [ { updates }, .. ])
+            (int, list[dict]): current stream id and list of updates
         """
         now_stream_id = self._device_list_id_gen.get_current_token()
 
@@ -270,24 +348,27 @@ class DeviceStore(SQLBaseStore):
             SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
             WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
             GROUP BY user_id, device_id
+            LIMIT 20
         """
         txn.execute(
             sql, (destination, from_stream_id, now_stream_id, False)
         )
-        rows = txn.fetchall()
 
-        if not rows:
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in txn}
+        if not query_map:
             return (now_stream_id, [])
 
-        # maps (user_id, device_id) -> stream_id
-        query_map = {(r[0], r[1]): r[2] for r in rows}
+        if len(query_map) >= 20:
+            now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+
         devices = self._get_e2e_device_keys_txn(
             txn, query_map.keys(), include_all_devices=True
         )
 
         prev_sent_id_sql = """
             SELECT coalesce(max(stream_id), 0) as stream_id
-            FROM device_lists_outbound_pokes
+            FROM device_lists_outbound_last_success
             WHERE destination = ? AND user_id = ? AND stream_id <= ?
         """
 
@@ -320,6 +401,7 @@ class DeviceStore(SQLBaseStore):
 
         return (now_stream_id, results)
 
+    @defer.inlineCallbacks
     def get_user_devices_from_cache(self, query_list):
         """Get the devices (and keys if any) for remote users from the cache.
 
@@ -332,27 +414,11 @@ class DeviceStore(SQLBaseStore):
             a set of user_ids and results_map is a mapping of
             user_id -> device_id -> device_info
         """
-        return self.runInteraction(
-            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
-            query_list,
+        user_ids = set(user_id for user_id, _ in query_list)
+        user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        user_ids_in_cache = set(
+            user_id for user_id, stream_id in user_map.items() if stream_id
         )
-
-    def _get_user_devices_from_cache_txn(self, txn, query_list):
-        user_ids = {user_id for user_id, _ in query_list}
-
-        user_ids_in_cache = set()
-        for user_id in user_ids:
-            stream_ids = self._simple_select_onecol_txn(
-                txn,
-                table="device_lists_remote_extremeties",
-                keyvalues={
-                    "user_id": user_id,
-                },
-                retcol="stream_id",
-            )
-            if stream_ids:
-                user_ids_in_cache.add(user_id)
-
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
         results = {}
@@ -361,32 +427,40 @@ class DeviceStore(SQLBaseStore):
                 continue
 
             if device_id:
-                content = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="device_lists_remote_cache",
-                    keyvalues={
-                        "user_id": user_id,
-                        "device_id": device_id,
-                    },
-                    retcol="content",
-                )
-                results.setdefault(user_id, {})[device_id] = json.loads(content)
+                device = yield self._get_cached_user_device(user_id, device_id)
+                results.setdefault(user_id, {})[device_id] = device
             else:
-                devices = self._simple_select_list_txn(
-                    txn,
-                    table="device_lists_remote_cache",
-                    keyvalues={
-                        "user_id": user_id,
-                    },
-                    retcols=("device_id", "content"),
-                )
-                results[user_id] = {
-                    device["device_id"]: json.loads(device["content"])
-                    for device in devices
-                }
-                user_ids_in_cache.discard(user_id)
+                results[user_id] = yield self._get_cached_devices_for_user(user_id)
 
-        return user_ids_not_in_cache, results
+        defer.returnValue((user_ids_not_in_cache, results))
+
+    @cachedInlineCallbacks(num_args=2, tree=True)
+    def _get_cached_user_device(self, user_id, device_id):
+        content = yield self._simple_select_one_onecol(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            retcol="content",
+            desc="_get_cached_user_device",
+        )
+        defer.returnValue(json.loads(content))
+
+    @cachedInlineCallbacks()
+    def _get_cached_devices_for_user(self, user_id):
+        devices = yield self._simple_select_list(
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+            retcols=("device_id", "content"),
+            desc="_get_cached_devices_for_user",
+        )
+        defer.returnValue({
+            device["device_id"]: json.loads(device["content"])
+            for device in devices
+        })
 
     def get_devices_with_keys_by_user(self, user_id):
         """Get all devices (with any device keys) for a user
@@ -436,32 +510,43 @@ class DeviceStore(SQLBaseStore):
         )
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
-        # First we DELETE all rows such that only the latest row for each
-        # (destination, user_id is left. We do this by selecting first and
-        # deleting.
+        # We update the device_lists_outbound_last_success with the successfully
+        # poked users. We do the join to see which users need to be inserted and
+        # which updated.
         sql = """
-            SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
-            WHERE destination = ? AND stream_id <= ?
+            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            FROM device_lists_outbound_pokes as o
+            LEFT JOIN device_lists_outbound_last_success as s
+                USING (destination, user_id)
+            WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
-            HAVING count(*) > 1
         """
         txn.execute(sql, (destination, stream_id,))
         rows = txn.fetchall()
 
         sql = """
-            DELETE FROM device_lists_outbound_pokes
-            WHERE destination = ? AND user_id = ? AND stream_id < ?
+            UPDATE device_lists_outbound_last_success
+            SET stream_id = ?
+            WHERE destination = ? AND user_id = ?
+        """
+        txn.executemany(
+            sql, ((row[1], destination, row[0],) for row in rows if row[2])
+        )
+
+        sql = """
+            INSERT INTO device_lists_outbound_last_success
+            (destination, user_id, stream_id) VALUES (?, ?, ?)
         """
         txn.executemany(
-            sql, ((destination, row[0], row[1],) for row in rows)
+            sql, ((destination, row[0], row[1],) for row in rows if not row[2])
         )
 
-        # Mark everything that is left as sent
+        # Delete all sent outbound pokes
         sql = """
-            UPDATE device_lists_outbound_pokes SET sent = ?
+            DELETE FROM device_lists_outbound_pokes
             WHERE destination = ? AND stream_id <= ?
         """
-        txn.execute(sql, (True, destination, stream_id,))
+        txn.execute(sql, (destination, stream_id,))
 
     @defer.inlineCallbacks
     def get_user_whose_devices_changed(self, from_key):
@@ -473,12 +558,12 @@ class DeviceStore(SQLBaseStore):
             defer.returnValue(set(changed))
 
         sql = """
-            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
         """
         rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
         defer.returnValue(set(row[0] for row in rows))
 
-    def get_all_device_list_changes_for_remotes(self, from_key):
+    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
         combined list of changes to devices, and which destinations need to be
         poked. `destination` may be None if no destinations need to be poked.
@@ -486,11 +571,11 @@ class DeviceStore(SQLBaseStore):
         sql = """
             SELECT stream_id, user_id, destination FROM device_lists_stream
             LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
-            WHERE stream_id > ?
+            WHERE ? < stream_id AND stream_id <= ?
         """
         return self._execute(
-            "get_users_and_hosts_device_list", None,
-            sql, from_key,
+            "get_all_device_list_changes_for_remotes", None,
+            sql, from_key, to_key
         )
 
     @defer.inlineCallbacks
@@ -518,6 +603,16 @@ class DeviceStore(SQLBaseStore):
                 host, stream_id,
             )
 
+        # Delete older entries in the table, as we really only care about
+        # when the latest change happened.
+        txn.executemany(
+            """
+            DELETE FROM device_lists_stream
+            WHERE user_id = ? AND device_id = ? AND stream_id < ?
+            """,
+            [(user_id, device_id, stream_id) for device_id in device_ids]
+        )
+
         self._simple_insert_many_txn(
             txn,
             table="device_lists_stream",
@@ -586,6 +681,14 @@ class DeviceStore(SQLBaseStore):
                 )
             )
 
+            # Since we've deleted unsent deltas, we need to remove the entry
+            # of last successful sent so that the prev_ids are correctly set.
+            sql = """
+                DELETE FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ?
+            """
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
             logger.info("Pruned %d device list outbound pokes", txn.rowcount)
 
         return self.runInteraction(
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 9caaf81f2c..d0c0059757 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple(
 )
 
 
-class DirectoryStore(SQLBaseStore):
-
+class DirectoryWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_association_from_room_alias(self, room_alias):
         """ Get's the room_id and server list for a given room_alias
@@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore):
             RoomAliasMapping(room_id, room_alias.to_string(), servers)
         )
 
+    def get_room_alias_creator(self, room_alias):
+        return self._simple_select_one_onecol(
+            table="room_aliases",
+            keyvalues={
+                "room_alias": room_alias,
+            },
+            retcol="creator",
+            desc="get_room_alias_creator",
+            allow_none=True
+        )
+
+    @cached(max_entries=5000)
+    def get_aliases_for_room(self, room_id):
+        return self._simple_select_onecol(
+            "room_aliases",
+            {"room_id": room_id},
+            "room_alias",
+            desc="get_aliases_for_room",
+        )
+
+
+class DirectoryStore(DirectoryWorkerStore):
     @defer.inlineCallbacks
     def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
         """ Creates an associatin between  a room alias and room_id/servers
@@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore):
             )
         defer.returnValue(ret)
 
-    def get_room_alias_creator(self, room_alias):
-        return self._simple_select_one_onecol(
-            table="room_aliases",
-            keyvalues={
-                "room_alias": room_alias,
-            },
-            retcol="creator",
-            desc="get_room_alias_creator",
-            allow_none=True
-        )
-
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
         room_id = yield self.runInteraction(
@@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore):
             room_alias,
         )
 
-        self.get_aliases_for_room.invalidate((room_id,))
         defer.returnValue(room_id)
 
     def _delete_room_alias_txn(self, txn, room_alias):
@@ -160,13 +169,22 @@ class DirectoryStore(SQLBaseStore):
             (room_alias.to_string(),)
         )
 
+        self._invalidate_cache_and_stream(
+            txn, self.get_aliases_for_room, (room_id,)
+        )
+
         return room_id
 
-    @cached(max_entries=5000)
-    def get_aliases_for_room(self, room_id):
-        return self._simple_select_onecol(
-            "room_aliases",
-            {"room_id": room_id},
-            "room_alias",
-            desc="get_aliases_for_room",
+    def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+        def _update_aliases_for_room_txn(txn):
+            sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
+            txn.execute(sql, (new_room_id, creator, old_room_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_aliases_for_room, (old_room_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_aliases_for_room, (new_room_id,)
+            )
+        return self.runInteraction(
+            "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index b9f1365f92..ff8538ddf8 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,8 +14,10 @@
 # limitations under the License.
 from twisted.internet import defer
 
+from synapse.util.caches.descriptors import cached
+
 from canonicaljson import encode_canonical_json
-import ujson as json
+import simplejson as json
 
 from ._base import SQLBaseStore
 
@@ -120,26 +122,77 @@ class EndToEndKeyStore(SQLBaseStore):
 
         return result
 
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
+    @defer.inlineCallbacks
+    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+        """Retrieve a number of one-time keys for a user
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            key_ids(list[str]): list of key ids (excluding algorithm) to
+                retrieve
+
+        Returns:
+            deferred resolving to Dict[(str, str), str]: map from (algorithm,
+            key_id) to json string for key
+        """
+
+        rows = yield self._simple_select_many_batch(
+            table="e2e_one_time_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            retcols=("algorithm", "key_id", "key_json",),
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            desc="add_e2e_one_time_keys_check",
+        )
+
+        defer.returnValue({
+            (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+        """Insert some new one time keys for a device. Errors if any of the
+        keys already exist.
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            time_now(long): insertion time to record (ms since epoch)
+            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+                (algorithm, key_id, key json)
+        """
+
         def _add_e2e_one_time_keys(txn):
-            for (algorithm, key_id, json_bytes) in key_list:
-                self._simple_upsert_txn(
-                    txn, table="e2e_one_time_keys_json",
-                    keyvalues={
+            # We are protected from race between lookup and insertion due to
+            # a unique constraint. If there is a race of two calls to
+            # `add_e2e_one_time_keys` then they'll conflict and we will only
+            # insert one set.
+            self._simple_insert_many_txn(
+                txn, table="e2e_one_time_keys_json",
+                values=[
+                    {
                         "user_id": user_id,
                         "device_id": device_id,
                         "algorithm": algorithm,
                         "key_id": key_id,
-                    },
-                    values={
                         "ts_added_ms": time_now,
                         "key_json": json_bytes,
                     }
-                )
-        return self.runInteraction(
-            "add_e2e_one_time_keys", _add_e2e_one_time_keys
+                    for algorithm, key_id, json_bytes in new_keys
+                ],
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+            )
+        yield self.runInteraction(
+            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
+    @cached(max_entries=10000)
     def count_e2e_one_time_keys(self, user_id, device_id):
         """ Count the number of one time keys the server has for a device
         Returns:
@@ -153,7 +206,7 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             txn.execute(sql, (user_id, device_id))
             result = {}
-            for algorithm, key_count in txn.fetchall():
+            for algorithm, key_count in txn:
                 result[algorithm] = key_count
             return result
         return self.runInteraction(
@@ -174,7 +227,7 @@ class EndToEndKeyStore(SQLBaseStore):
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn.fetchall():
+                for key_id, key_json in txn:
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
             sql = (
@@ -184,20 +237,29 @@ class EndToEndKeyStore(SQLBaseStore):
             )
             for user_id, device_id, algorithm, key_id in delete:
                 txn.execute(sql, (user_id, device_id, algorithm, key_id))
+                self._invalidate_cache_and_stream(
+                    txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+                )
             return result
         return self.runInteraction(
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
-        yield self._simple_delete(
-            table="e2e_device_keys_json",
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="delete_e2e_device_keys_by_device"
-        )
-        yield self._simple_delete(
-            table="e2e_one_time_keys_json",
-            keyvalues={"user_id": user_id, "device_id": device_id},
-            desc="delete_e2e_one_time_keys_by_device"
+        def delete_e2e_keys_by_device_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
+            )
+        return self.runInteraction(
+            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..8c868ece75 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,6 +18,7 @@ from .postgres import PostgresEngine
 from .sqlite3 import Sqlite3Engine
 
 import importlib
+import platform
 
 
 SUPPORTED_MODULE = {
@@ -31,6 +32,10 @@ def create_engine(database_config):
     engine_class = SUPPORTED_MODULE.get(name, None)
 
     if engine_class:
+        # pypy requires psycopg2cffi rather than psycopg2
+        if (name == "psycopg2" and
+                platform.python_implementation() == "PyPy"):
+            name = "psycopg2cffi"
         module = importlib.import_module(name)
         return engine_class(module, database_config)
 
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
 
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        txn.execute("SELECT nextval('state_group_id_seq')")
+        return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
 from synapse.storage.prepare_database import prepare_database
 
 import struct
+import threading
 
 
 class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
     def __init__(self, database_module, database_config):
         self.module = database_module
 
+        # The current max state_group, or None if we haven't looked
+        # in the DB yet.
+        self._current_state_group_id = None
+        self._current_state_group_id_lock = threading.Lock()
+
     def check_database(self, txn):
         pass
 
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
     def lock_table(self, txn, table):
         return
 
+    def get_next_state_group_id(self, txn):
+        """Returns an int that can be used as a new state_group ID
+        """
+        # We do application locking here since if we're using sqlite then
+        # we are a single process synapse.
+        with self._current_state_group_id_lock:
+            if self._current_state_group_id is None:
+                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+                self._current_state_group_id = txn.fetchone()[0]
+
+            self._current_state_group_id += 1
+            return self._current_state_group_id
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 256e50dc20..8fbf7ffba7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,50 +12,64 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.storage.signatures import SignatureWorkerStore
+
 from synapse.api.errors import StoreError
 from synapse.util.caches.descriptors import cached
 from unpaddedbase64 import encode_base64
 
 import logging
-from Queue import PriorityQueue, Empty
+from six.moves.queue import PriorityQueue, Empty
+
+from six.moves import range
 
 
 logger = logging.getLogger(__name__)
 
 
-class EventFederationStore(SQLBaseStore):
-    """ Responsible for storing and serving up the various graphs associated
-    with an event. Including the main event graph and the auth chains for an
-    event.
+class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
+                                 SQLBaseStore):
+    def get_auth_chain(self, event_ids, include_given=False):
+        """Get auth events for given event_ids. The events *must* be state events.
 
-    Also has methods for getting the front (latest) and back (oldest) edges
-    of the event graphs. These are used to generate the parents for new events
-    and backfilling from another server respectively.
-    """
+        Args:
+            event_ids (list): state events
+            include_given (bool): include the given events in result
 
-    def __init__(self, hs):
-        super(EventFederationStore, self).__init__(hs)
+        Returns:
+            list of events
+        """
+        return self.get_auth_chain_ids(
+            event_ids, include_given=include_given,
+        ).addCallback(self._get_events)
 
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
+    def get_auth_chain_ids(self, event_ids, include_given=False):
+        """Get auth events for given event_ids. The events *must* be state events.
 
-    def get_auth_chain(self, event_ids):
-        return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
+        Args:
+            event_ids (list): state events
+            include_given (bool): include the given events in result
 
-    def get_auth_chain_ids(self, event_ids):
+        Returns:
+            list of event_ids
+        """
         return self.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
-            event_ids
+            event_ids, include_given
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids):
-        results = set()
+    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+        if include_given:
+            results = set(event_ids)
+        else:
+            results = set()
 
         base_sql = (
             "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
@@ -67,14 +81,14 @@ class EventFederationStore(SQLBaseStore):
             front_list = list(front)
             chunks = [
                 front_list[x:x + 100]
-                for x in xrange(0, len(front), 100)
+                for x in range(0, len(front), 100)
             ]
             for chunk in chunks:
                 txn.execute(
                     base_sql % (",".join(["?"] * len(chunk)),),
                     chunk
                 )
-                new_front.update([r[0] for r in txn.fetchall()])
+                new_front.update([r[0] for r in txn])
 
             new_front -= results
 
@@ -110,7 +124,7 @@ class EventFederationStore(SQLBaseStore):
 
         txn.execute(sql, (room_id, False,))
 
-        return dict(txn.fetchall())
+        return dict(txn)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
@@ -122,7 +136,47 @@ class EventFederationStore(SQLBaseStore):
             retcol="event_id",
         )
 
+    @defer.inlineCallbacks
+    def get_prev_events_for_room(self, room_id):
+        """
+        Gets a subset of the current forward extremities in the given room.
+
+        Limits the result to 10 extremities, so that we can avoid creating
+        events which refer to hundreds of prev_events.
+
+        Args:
+            room_id (str): room_id
+
+        Returns:
+            Deferred[list[(str, dict[str, str], int)]]
+                for each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+        """
+        res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
+        if len(res) > 10:
+            # Sort by reverse depth, so we point to the most recent.
+            res.sort(key=lambda a: -a[2])
+
+            # we use half of the limit for the actual most recent events, and
+            # the other half to randomly point to some of the older events, to
+            # make sure that we don't completely ignore the older events.
+            res = res[0:5] + random.sample(res[5:], 5)
+
+        defer.returnValue(res)
+
     def get_latest_event_ids_and_hashes_in_room(self, room_id):
+        """
+        Gets the current forward extremities in the given room
+
+        Args:
+            room_id (str): room_id
+
+        Returns:
+            Deferred[list[(str, dict[str, str], int)]]
+                for each event, a tuple of (event_id, hashes, depth)
+                where *hashes* is a map from algorithm to hash.
+        """
+
         return self.runInteraction(
             "get_latest_event_ids_and_hashes_in_room",
             self._get_latest_event_ids_and_hashes_in_room,
@@ -171,22 +225,6 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_max_depth_of_events(self, event_ids):
-        sql = (
-            "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
-        ) % (",".join(["?"] * len(event_ids)),)
-
-        rows = yield self._execute(
-            "get_max_depth_of_events", None,
-            sql, *event_ids
-        )
-
-        if rows:
-            defer.returnValue(rows[0][0])
-        else:
-            defer.returnValue(1)
-
     def _get_min_depth_interaction(self, txn, room_id):
         min_depth = self._simple_select_one_onecol_txn(
             txn,
@@ -198,88 +236,6 @@ class EventFederationStore(SQLBaseStore):
 
         return int(min_depth) if min_depth is not None else None
 
-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
-        min_depth = self._get_min_depth_interaction(txn, room_id)
-
-        do_insert = depth < min_depth if min_depth else True
-
-        if do_insert:
-            self._simple_upsert_txn(
-                txn,
-                table="room_depth",
-                keyvalues={
-                    "room_id": room_id,
-                },
-                values={
-                    "min_depth": depth,
-                },
-            )
-
-    def _handle_mult_prev_events(self, txn, events):
-        """
-        For the given event, update the event edges table and forward and
-        backward extremities tables.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="event_edges",
-            values=[
-                {
-                    "event_id": ev.event_id,
-                    "prev_event_id": e_id,
-                    "room_id": ev.room_id,
-                    "is_state": False,
-                }
-                for ev in events
-                for e_id, _ in ev.prev_events
-            ],
-        )
-
-        self._update_backward_extremeties(txn, events)
-
-    def _update_backward_extremeties(self, txn, events):
-        """Updates the event_backward_extremities tables based on the new/updated
-        events being persisted.
-
-        This is called for new events *and* for events that were outliers, but
-        are now being persisted as non-outliers.
-
-        Forward extremities are handled when we first start persisting the events.
-        """
-        events_by_room = {}
-        for ev in events:
-            events_by_room.setdefault(ev.room_id, []).append(ev)
-
-        query = (
-            "INSERT INTO event_backward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-            " )"
-            " AND NOT EXISTS ("
-            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            " AND outlier = ?"
-            " )"
-        )
-
-        txn.executemany(query, [
-            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
-            for ev in events for e_id, _ in ev.prev_events
-            if not ev.internal_metadata.is_outlier()
-        ])
-
-        query = (
-            "DELETE FROM event_backward_extremities"
-            " WHERE event_id = ? AND room_id = ?"
-        )
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id) for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ]
-        )
-
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
@@ -334,36 +290,13 @@ class EventFederationStore(SQLBaseStore):
 
         def get_forward_extremeties_for_room_txn(txn):
             txn.execute(sql, (stream_ordering, room_id))
-            rows = txn.fetchall()
-            return [event_id for event_id, in rows]
+            return [event_id for event_id, in txn]
 
         return self.runInteraction(
             "get_forward_extremeties_for_room",
             get_forward_extremeties_for_room_txn
         )
 
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = ("""
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """)
-            txn.execute(
-                sql,
-                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
-            )
-        return self.runInteraction(
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn
-        )
-
     def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
@@ -436,7 +369,7 @@ class EventFederationStore(SQLBaseStore):
                 (room_id, event_id, False, limit - len(event_results))
             )
 
-            for row in txn.fetchall():
+            for row in txn:
                 if row[1] not in event_results:
                     queue.put((-row[0], row[1]))
 
@@ -482,7 +415,7 @@ class EventFederationStore(SQLBaseStore):
                     (room_id, event_id, False, limit - len(event_results))
                 )
 
-                for e_id, in txn.fetchall():
+                for e_id, in txn:
                     new_front.add(e_id)
 
             new_front -= earliest_events
@@ -493,6 +426,135 @@ class EventFederationStore(SQLBaseStore):
 
         return event_results
 
+
+class EventFederationStore(EventFederationWorkerStore):
+    """ Responsible for storing and serving up the various graphs associated
+    with an event. Including the main event graph and the auth chains for an
+    event.
+
+    Also has methods for getting the front (latest) and back (oldest) edges
+    of the event graphs. These are used to generate the parents for new events
+    and backfilling from another server respectively.
+    """
+
+    EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
+
+    def __init__(self, db_conn, hs):
+        super(EventFederationStore, self).__init__(db_conn, hs)
+
+        self.register_background_update_handler(
+            self.EVENT_AUTH_STATE_ONLY,
+            self._background_delete_non_state_event_auth,
+        )
+
+        hs.get_clock().looping_call(
+            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+        )
+
+    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+        min_depth = self._get_min_depth_interaction(txn, room_id)
+
+        if min_depth and depth >= min_depth:
+            return
+
+        self._simple_upsert_txn(
+            txn,
+            table="room_depth",
+            keyvalues={
+                "room_id": room_id,
+            },
+            values={
+                "min_depth": depth,
+            },
+        )
+
+    def _handle_mult_prev_events(self, txn, events):
+        """
+        For the given event, update the event edges table and forward and
+        backward extremities tables.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="event_edges",
+            values=[
+                {
+                    "event_id": ev.event_id,
+                    "prev_event_id": e_id,
+                    "room_id": ev.room_id,
+                    "is_state": False,
+                }
+                for ev in events
+                for e_id, _ in ev.prev_events
+            ],
+        )
+
+        self._update_backward_extremeties(txn, events)
+
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
+        events being persisted.
+
+        This is called for new events *and* for events that were outliers, but
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
+        """
+        events_by_room = {}
+        for ev in events:
+            events_by_room.setdefault(ev.room_id, []).append(ev)
+
+        query = (
+            "INSERT INTO event_backward_extremities (event_id, room_id)"
+            " SELECT ?, ? WHERE NOT EXISTS ("
+            " SELECT 1 FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+            " )"
+            " AND NOT EXISTS ("
+            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+            " AND outlier = ?"
+            " )"
+        )
+
+        txn.executemany(query, [
+            (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+            for ev in events for e_id, _ in ev.prev_events
+            if not ev.internal_metadata.is_outlier()
+        ])
+
+        query = (
+            "DELETE FROM event_backward_extremities"
+            " WHERE event_id = ? AND room_id = ?"
+        )
+        txn.executemany(
+            query,
+            [
+                (ev.event_id, ev.room_id) for ev in events
+                if not ev.internal_metadata.is_outlier()
+            ]
+        )
+
+    def _delete_old_forward_extrem_cache(self):
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = ("""
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """)
+            txn.execute(
+                sql,
+                (self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
+            )
+        return self.runInteraction(
+            "_delete_old_forward_extrem_cache",
+            _delete_old_forward_extrem_cache_txn
+        )
+
     def clean_room_for_join(self, room_id):
         return self.runInteraction(
             "clean_room_for_join",
@@ -505,3 +567,52 @@ class EventFederationStore(SQLBaseStore):
 
         txn.execute(query, (room_id,))
         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+
+    @defer.inlineCallbacks
+    def _background_delete_non_state_event_auth(self, progress, batch_size):
+        def delete_event_auth(txn):
+            target_min_stream_id = progress.get("target_min_stream_id_inclusive")
+            max_stream_id = progress.get("max_stream_id_exclusive")
+
+            if not target_min_stream_id or not max_stream_id:
+                txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events")
+                rows = txn.fetchall()
+                target_min_stream_id = rows[0][0]
+
+                txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events")
+                rows = txn.fetchall()
+                max_stream_id = rows[0][0]
+
+            min_stream_id = max_stream_id - batch_size
+
+            sql = """
+                DELETE FROM event_auth
+                WHERE event_id IN (
+                    SELECT event_id FROM events
+                    LEFT JOIN state_events USING (room_id, event_id)
+                    WHERE ? <= stream_ordering AND stream_ordering < ?
+                        AND state_key IS null
+                )
+            """
+
+            txn.execute(sql, (min_stream_id, max_stream_id,))
+
+            new_progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+            }
+
+            self._background_update_progress_txn(
+                txn, self.EVENT_AUTH_STATE_ONLY, new_progress
+            )
+
+            return min_stream_id >= target_min_stream_id
+
+        result = yield self.runInteraction(
+            self.EVENT_AUTH_STATE_ONLY, delete_event_auth
+        )
+
+        if not result:
+            yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+
+        defer.returnValue(batch_size)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 522d0114cb..c22762eb5c 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,131 +14,159 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore, LoggingTransaction
 from twisted.internet import defer
+from synapse.util.async import sleep
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.types import RoomStreamToken
 from .stream import lower_bound
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
 
-class EventPushActionsStore(SQLBaseStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
+DEFAULT_HIGHLIGHT_ACTION = [
+    "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
+]
 
-    def __init__(self, hs):
-        self.stream_ordering_month_ago = None
-        super(EventPushActionsStore, self).__init__(hs)
 
-        self.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
+def _serialize_action(actions, is_highlight):
+    """Custom serializer for actions. This allows us to "compress" common actions.
 
-        self.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1"
-        )
+    We use the fact that most users have the same actions for notifs (and for
+    highlights).
+    We store these default actions as the empty string rather than the full JSON.
+    Since the empty string isn't valid JSON there is no risk of this clashing with
+    any real JSON actions
+    """
+    if is_highlight:
+        if actions == DEFAULT_HIGHLIGHT_ACTION:
+            return ""  # We use empty string as the column is non-NULL
+    else:
+        if actions == DEFAULT_NOTIF_ACTION:
+            return ""
+    return json.dumps(actions)
 
-    def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
-        """
-        Args:
-            event: the event set actions for
-            tuples: list of tuples of (user_id, actions)
-        """
-        values = []
-        for uid, actions in tuples:
-            values.append({
-                'room_id': event.room_id,
-                'event_id': event.event_id,
-                'user_id': uid,
-                'actions': json.dumps(actions),
-                'stream_ordering': event.internal_metadata.stream_ordering,
-                'topological_ordering': event.depth,
-                'notif': 1,
-                'highlight': 1 if _action_has_highlight(actions) else 0,
-            })
-
-        for uid, __ in tuples:
-            txn.call_after(
-                self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-                (event.room_id, uid)
-            )
-        self._simple_insert_many_txn(txn, "event_push_actions", values)
 
-    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
-    def get_unread_event_push_actions_by_room_for_user(
-            self, room_id, user_id, last_read_event_id
-    ):
-        def _get_unread_event_push_actions_by_room(txn):
-            sql = (
-                "SELECT stream_ordering, topological_ordering"
-                " FROM events"
-                " WHERE room_id = ? AND event_id = ?"
-            )
-            txn.execute(
-                sql, (room_id, last_read_event_id)
-            )
-            results = txn.fetchall()
-            if len(results) == 0:
-                return {"notify_count": 0, "highlight_count": 0}
-
-            stream_ordering = results[0][0]
-            topological_ordering = results[0][1]
-            token = RoomStreamToken(
-                topological_ordering, stream_ordering
-            )
+def _deserialize_action(actions, is_highlight):
+    """Custom deserializer for actions. This allows us to "compress" common actions
+    """
+    if actions:
+        return json.loads(actions)
 
-            # First get number of notifications.
-            # We don't need to put a notif=1 clause as all rows always have
-            # notif=1
-            sql = (
-                "SELECT count(*)"
-                " FROM event_push_actions ea"
-                " WHERE"
-                " user_id = ?"
-                " AND room_id = ?"
-                " AND %s"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
+    if is_highlight:
+        return DEFAULT_HIGHLIGHT_ACTION
+    else:
+        return DEFAULT_NOTIF_ACTION
 
-            txn.execute(sql, (user_id, room_id))
-            row = txn.fetchone()
-            notify_count = row[0] if row else 0
 
-            # Now get the number of highlights
-            sql = (
-                "SELECT count(*)"
-                " FROM event_push_actions ea"
-                " WHERE"
-                " highlight = 1"
-                " AND user_id = ?"
-                " AND room_id = ?"
-                " AND %s"
-            ) % (lower_bound(token, self.database_engine, inclusive=False),)
+class EventPushActionsWorkerStore(SQLBaseStore):
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
 
-            txn.execute(sql, (user_id, room_id))
-            row = txn.fetchone()
-            highlight_count = row[0] if row else 0
+        # These get correctly set by _find_stream_orderings_for_times_txn
+        self.stream_ordering_month_ago = None
+        self.stream_ordering_day_ago = None
+
+        cur = LoggingTransaction(
+            db_conn.cursor(),
+            name="_find_stream_orderings_for_times_txn",
+            database_engine=self.database_engine,
+            after_callbacks=[],
+            exception_callbacks=[],
+        )
+        self._find_stream_orderings_for_times_txn(cur)
+        cur.close()
 
-            return {
-                "notify_count": notify_count,
-                "highlight_count": highlight_count,
-            }
+        self.find_stream_orderings_looping_call = self._clock.looping_call(
+            self._find_stream_orderings_for_times, 10 * 60 * 1000
+        )
 
+    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
+    def get_unread_event_push_actions_by_room_for_user(
+            self, room_id, user_id, last_read_event_id
+    ):
         ret = yield self.runInteraction(
             "get_unread_event_push_actions_by_room",
-            _get_unread_event_push_actions_by_room
+            self._get_unread_counts_by_receipt_txn,
+            room_id, user_id, last_read_event_id
         )
         defer.returnValue(ret)
 
+    def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
+                                          last_read_event_id):
+        sql = (
+            "SELECT stream_ordering, topological_ordering"
+            " FROM events"
+            " WHERE room_id = ? AND event_id = ?"
+        )
+        txn.execute(
+            sql, (room_id, last_read_event_id)
+        )
+        results = txn.fetchall()
+        if len(results) == 0:
+            return {"notify_count": 0, "highlight_count": 0}
+
+        stream_ordering = results[0][0]
+        topological_ordering = results[0][1]
+
+        return self._get_unread_counts_by_pos_txn(
+            txn, room_id, user_id, topological_ordering, stream_ordering
+        )
+
+    def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
+                                      stream_ordering):
+        token = RoomStreamToken(
+            topological_ordering, stream_ordering
+        )
+
+        # First get number of notifications.
+        # We don't need to put a notif=1 clause as all rows always have
+        # notif=1
+        sql = (
+            "SELECT count(*)"
+            " FROM event_push_actions ea"
+            " WHERE"
+            " user_id = ?"
+            " AND room_id = ?"
+            " AND %s"
+        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+        txn.execute(sql, (user_id, room_id))
+        row = txn.fetchone()
+        notify_count = row[0] if row else 0
+
+        txn.execute("""
+            SELECT notif_count FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+        """, (room_id, user_id, stream_ordering,))
+        rows = txn.fetchall()
+        if rows:
+            notify_count += rows[0][0]
+
+        # Now get the number of highlights
+        sql = (
+            "SELECT count(*)"
+            " FROM event_push_actions ea"
+            " WHERE"
+            " highlight = 1"
+            " AND user_id = ?"
+            " AND room_id = ?"
+            " AND %s"
+        ) % (lower_bound(token, self.database_engine, inclusive=False),)
+
+        txn.execute(sql, (user_id, room_id))
+        row = txn.fetchone()
+        highlight_count = row[0] if row else 0
+
+        return {
+            "notify_count": notify_count,
+            "highlight_count": highlight_count,
+        }
+
     @defer.inlineCallbacks
     def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
         def f(txn):
@@ -146,7 +175,7 @@ class EventPushActionsStore(SQLBaseStore):
                 " stream_ordering >= ? AND stream_ordering <= ?"
             )
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
-            return [r[0] for r in txn.fetchall()]
+            return [r[0] for r in txn]
         ret = yield self.runInteraction("get_push_action_users_in_range", f)
         defer.returnValue(ret)
 
@@ -176,7 +205,8 @@ class EventPushActionsStore(SQLBaseStore):
             # find rooms that have a read receipt in them and return the next
             # push actions
             sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions"
+                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
+                "   ep.highlight "
                 " FROM ("
                 "   SELECT room_id,"
                 "       MAX(topological_ordering) as topological_ordering,"
@@ -217,7 +247,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_no_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                " e.received_ts"
+                "   ep.highlight "
                 " FROM event_push_actions AS ep"
                 " INNER JOIN events AS e USING (room_id, event_id)"
                 " WHERE"
@@ -246,7 +276,7 @@ class EventPushActionsStore(SQLBaseStore):
                 "event_id": row[0],
                 "room_id": row[1],
                 "stream_ordering": row[2],
-                "actions": json.loads(row[3]),
+                "actions": _deserialize_action(row[3], row[4]),
             } for row in after_read_receipt + no_read_receipt
         ]
 
@@ -285,7 +315,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_after_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  e.received_ts"
+                "  ep.highlight, e.received_ts"
                 " FROM ("
                 "   SELECT room_id,"
                 "       MAX(topological_ordering) as topological_ordering,"
@@ -327,7 +357,7 @@ class EventPushActionsStore(SQLBaseStore):
         def get_no_receipt(txn):
             sql = (
                 "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                " e.received_ts"
+                "   ep.highlight, e.received_ts"
                 " FROM event_push_actions AS ep"
                 " INNER JOIN events AS e USING (room_id, event_id)"
                 " WHERE"
@@ -357,8 +387,8 @@ class EventPushActionsStore(SQLBaseStore):
                 "event_id": row[0],
                 "room_id": row[1],
                 "stream_ordering": row[2],
-                "actions": json.loads(row[3]),
-                "received_ts": row[4],
+                "actions": _deserialize_action(row[3], row[4]),
+                "received_ts": row[5],
             } for row in after_read_receipt + no_read_receipt
         ]
 
@@ -371,6 +401,290 @@ class EventPushActionsStore(SQLBaseStore):
         # Now return the first `limit`
         defer.returnValue(notifs[:limit])
 
+    def add_push_actions_to_staging(self, event_id, user_id_actions):
+        """Add the push actions for the event to the push action staging area.
+
+        Args:
+            event_id (str)
+            user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
+                user_id to list of push actions, where an action can either be
+                a string or dict.
+
+        Returns:
+            Deferred
+        """
+
+        if not user_id_actions:
+            return
+
+        # This is a helper function for generating the necessary tuple that
+        # can be used to inert into the `event_push_actions_staging` table.
+        def _gen_entry(user_id, actions):
+            is_highlight = 1 if _action_has_highlight(actions) else 0
+            return (
+                event_id,  # event_id column
+                user_id,  # user_id column
+                _serialize_action(actions, is_highlight),  # actions column
+                1,  # notif column
+                is_highlight,  # highlight column
+            )
+
+        def _add_push_actions_to_staging_txn(txn):
+            # We don't use _simple_insert_many here to avoid the overhead
+            # of generating lists of dicts.
+
+            sql = """
+                INSERT INTO event_push_actions_staging
+                    (event_id, user_id, actions, notif, highlight)
+                VALUES (?, ?, ?, ?, ?)
+            """
+
+            txn.executemany(sql, (
+                _gen_entry(user_id, actions)
+                for user_id, actions in user_id_actions.iteritems()
+            ))
+
+        return self.runInteraction(
+            "add_push_actions_to_staging", _add_push_actions_to_staging_txn
+        )
+
+    @defer.inlineCallbacks
+    def remove_push_actions_from_staging(self, event_id):
+        """Called if we failed to persist the event to ensure that stale push
+        actions don't build up in the DB
+
+        Args:
+            event_id (str)
+        """
+
+        try:
+            res = yield self._simple_delete(
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event_id,
+                },
+                desc="remove_push_actions_from_staging",
+            )
+            defer.returnValue(res)
+        except Exception:
+            # this method is called from an exception handler, so propagating
+            # another exception here really isn't helpful - there's nothing
+            # the caller can do about it. Just log the exception and move on.
+            logger.exception(
+                "Error removing push actions after event persistence failure",
+            )
+
+    @defer.inlineCallbacks
+    def _find_stream_orderings_for_times(self):
+        yield self.runInteraction(
+            "_find_stream_orderings_for_times",
+            self._find_stream_orderings_for_times_txn
+        )
+
+    def _find_stream_orderings_for_times_txn(self, txn):
+        logger.info("Searching for stream ordering 1 month ago")
+        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 month ago: it's %d",
+            self.stream_ordering_month_ago
+        )
+        logger.info("Searching for stream ordering 1 day ago")
+        self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
+            txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
+        )
+        logger.info(
+            "Found stream ordering 1 day ago: it's %d",
+            self.stream_ordering_day_ago
+        )
+
+    def find_first_stream_ordering_after_ts(self, ts):
+        """Gets the stream ordering corresponding to a given timestamp.
+
+        Specifically, finds the stream_ordering of the first event that was
+        received on or after the timestamp. This is done by a binary search on
+        the events table, since there is no index on received_ts, so is
+        relatively slow.
+
+        Args:
+            ts (int): timestamp in millis
+
+        Returns:
+            Deferred[int]: stream ordering of the first event received on/after
+                the timestamp
+        """
+        return self.runInteraction(
+            "_find_first_stream_ordering_after_ts_txn",
+            self._find_first_stream_ordering_after_ts_txn,
+            ts,
+        )
+
+    @staticmethod
+    def _find_first_stream_ordering_after_ts_txn(txn, ts):
+        """
+        Find the stream_ordering of the first event that was received on or
+        after a given timestamp. This is relatively slow as there is no index
+        on received_ts but we can then use this to delete push actions before
+        this.
+
+        received_ts must necessarily be in the same order as stream_ordering
+        and stream_ordering is indexed, so we manually binary search using
+        stream_ordering
+
+        Args:
+            txn (twisted.enterprise.adbapi.Transaction):
+            ts (int): timestamp to search for
+
+        Returns:
+            int: stream ordering
+        """
+        txn.execute("SELECT MAX(stream_ordering) FROM events")
+        max_stream_ordering = txn.fetchone()[0]
+
+        if max_stream_ordering is None:
+            return 0
+
+        # We want the first stream_ordering in which received_ts is greater
+        # than or equal to ts. Call this point X.
+        #
+        # We maintain the invariants:
+        #
+        #   range_start <= X <= range_end
+        #
+        range_start = 0
+        range_end = max_stream_ordering + 1
+
+        # Given a stream_ordering, look up the timestamp at that
+        # stream_ordering.
+        #
+        # The array may be sparse (we may be missing some stream_orderings).
+        # We treat the gaps as the same as having the same value as the
+        # preceding entry, because we will pick the lowest stream_ordering
+        # which satisfies our requirement of received_ts >= ts.
+        #
+        # For example, if our array of events indexed by stream_ordering is
+        # [10, <none>, 20], we should treat this as being equivalent to
+        # [10, 10, 20].
+        #
+        sql = (
+            "SELECT received_ts FROM events"
+            " WHERE stream_ordering <= ?"
+            " ORDER BY stream_ordering DESC"
+            " LIMIT 1"
+        )
+
+        while range_end - range_start > 0:
+            middle = (range_end + range_start) // 2
+            txn.execute(sql, (middle,))
+            row = txn.fetchone()
+            if row is None:
+                # no rows with stream_ordering<=middle
+                range_start = middle + 1
+                continue
+
+            middle_ts = row[0]
+            if ts > middle_ts:
+                # we got a timestamp lower than the one we were looking for.
+                # definitely need to look higher: X > middle.
+                range_start = middle + 1
+            else:
+                # we got a timestamp higher than (or the same as) the one we
+                # were looking for. We aren't yet sure about the point we
+                # looked up, but we can be sure that X <= middle.
+                range_end = middle
+
+        return range_end
+
+
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1"
+        )
+
+        self._doing_notif_rotation = False
+        self._rotate_notif_loop = self._clock.looping_call(
+            self._rotate_notifs, 30 * 60 * 1000
+        )
+
+    def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
+                                                  all_events_and_contexts):
+        """Handles moving push actions from staging table to main
+        event_push_actions table for all events in `events_and_contexts`.
+
+        Also ensures that all events in `all_events_and_contexts` are removed
+        from the push action staging area.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
+        """
+
+        sql = """
+            INSERT INTO event_push_actions (
+                room_id, event_id, user_id, actions, stream_ordering,
+                topological_ordering, notif, highlight
+            )
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            FROM event_push_actions_staging
+            WHERE event_id = ?
+        """
+
+        if events_and_contexts:
+            txn.executemany(sql, (
+                (
+                    event.room_id, event.internal_metadata.stream_ordering,
+                    event.depth, event.event_id,
+                )
+                for event, _ in events_and_contexts
+            ))
+
+        for event, _ in events_and_contexts:
+            user_ids = self._simple_select_onecol_txn(
+                txn,
+                table="event_push_actions_staging",
+                keyvalues={
+                    "event_id": event.event_id,
+                },
+                retcol="user_id",
+            )
+
+            for uid in user_ids:
+                txn.call_after(
+                    self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+                    (event.room_id, uid,)
+                )
+
+        # Now we delete the staging area for *all* events that were being
+        # persisted.
+        txn.executemany(
+            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+            (
+                (event.event_id,)
+                for event, _ in all_events_and_contexts
+            )
+        )
+
     @defer.inlineCallbacks
     def get_push_actions_for_user(self, user_id, before=None, limit=50,
                                   only_highlight=False):
@@ -392,7 +706,7 @@ class EventPushActionsStore(SQLBaseStore):
             sql = (
                 "SELECT epa.event_id, epa.room_id,"
                 " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.profile_tag, e.received_ts"
+                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
                 " FROM event_push_actions epa, events e"
                 " WHERE epa.event_id = e.event_id"
                 " AND epa.user_id = ? %s"
@@ -407,7 +721,7 @@ class EventPushActionsStore(SQLBaseStore):
             "get_push_actions_for_user", f
         )
         for pa in push_actions:
-            pa["actions"] = json.loads(pa["actions"])
+            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         defer.returnValue(push_actions)
 
     @defer.inlineCallbacks
@@ -448,7 +762,7 @@ class EventPushActionsStore(SQLBaseStore):
         )
 
     def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
-                                            topological_ordering):
+                                            topological_ordering, stream_ordering):
         """
         Purges old push actions for a user and room before a given
         topological_ordering.
@@ -479,65 +793,140 @@ class EventPushActionsStore(SQLBaseStore):
         txn.execute(
             "DELETE FROM event_push_actions "
             " WHERE user_id = ? AND room_id = ? AND "
-            " topological_ordering < ?"
+            " topological_ordering <= ?"
             " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
             (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
         )
 
+        txn.execute("""
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """, (room_id, user_id, stream_ordering))
+
     @defer.inlineCallbacks
-    def _find_stream_orderings_for_times(self):
-        yield self.runInteraction(
-            "_find_stream_orderings_for_times",
-            self._find_stream_orderings_for_times_txn
-        )
+    def _rotate_notifs(self):
+        if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
+            return
+        self._doing_notif_rotation = True
 
-    def _find_stream_orderings_for_times_txn(self, txn):
-        logger.info("Searching for stream ordering 1 month ago")
-        self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
-            txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
-        )
-        logger.info(
-            "Found stream ordering 1 month ago: it's %d",
-            self.stream_ordering_month_ago
+        try:
+            while True:
+                logger.info("Rotating notifications")
+
+                caught_up = yield self.runInteraction(
+                    "_rotate_notifs",
+                    self._rotate_notifs_txn
+                )
+                if caught_up:
+                    break
+                yield sleep(5)
+        finally:
+            self._doing_notif_rotation = False
+
+    def _rotate_notifs_txn(self, txn):
+        """Archives older notifications into event_push_summary. Returns whether
+        the archiving process has caught up or not.
+        """
+
+        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
         )
 
-    def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
-        """
-        Find the stream_ordering of the first event that was received after
-        a given timestamp. This is relatively slow as there is no index on
-        received_ts but we can then use this to delete push actions before
-        this.
+        # We don't to try and rotate millions of rows at once, so we cap the
+        # maximum stream ordering we'll rotate before.
+        txn.execute("""
+            SELECT stream_ordering FROM event_push_actions
+            WHERE stream_ordering > ?
+            ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
+        """, (old_rotate_stream_ordering,))
+        stream_row = txn.fetchone()
+        if stream_row:
+            offset_stream_ordering, = stream_row
+            rotate_to_stream_ordering = min(
+                self.stream_ordering_day_ago, offset_stream_ordering
+            )
+            caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
+        else:
+            rotate_to_stream_ordering = self.stream_ordering_day_ago
+            caught_up = True
 
-        received_ts must necessarily be in the same order as stream_ordering
-        and stream_ordering is indexed, so we manually binary search using
-        stream_ordering
+        logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
+
+        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+
+        # We have caught up iff we were limited by `stream_ordering_day_ago`
+        return caught_up
+
+    def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
+        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
+        # Calculate the new counts that should be upserted into event_push_summary
+        sql = """
+            SELECT user_id, room_id,
+                coalesce(old.notif_count, 0) + upd.notif_count,
+                upd.stream_ordering,
+                old.user_id
+            FROM (
+                SELECT user_id, room_id, count(*) as notif_count,
+                    max(stream_ordering) as stream_ordering
+                FROM event_push_actions
+                WHERE ? <= stream_ordering AND stream_ordering < ?
+                    AND highlight = 0
+                GROUP BY user_id, room_id
+            ) AS upd
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
-        txn.execute("SELECT MAX(stream_ordering) FROM events")
-        max_stream_ordering = txn.fetchone()[0]
 
-        if max_stream_ordering is None:
-            return 0
+        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
+        rows = txn.fetchall()
+
+        logger.info("Rotating notifications, handling %d rows", len(rows))
+
+        # If the `old.user_id` above is NULL then we know there isn't already an
+        # entry in the table, so we simply insert it. Otherwise we update the
+        # existing table.
+        self._simple_insert_many_txn(
+            txn,
+            table="event_push_summary",
+            values=[
+                {
+                    "user_id": row[0],
+                    "room_id": row[1],
+                    "notif_count": row[2],
+                    "stream_ordering": row[3],
+                }
+                for row in rows if row[4] is None
+            ]
+        )
 
-        range_start = 0
-        range_end = max_stream_ordering
+        txn.executemany(
+            """
+                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                WHERE user_id = ? AND room_id = ?
+            """,
+            ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
+        )
 
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering > ?"
-            " ORDER BY stream_ordering"
-            " LIMIT 1"
+        txn.execute(
+            "DELETE FROM event_push_actions"
+            " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
+            (old_rotate_stream_ordering, rotate_to_stream_ordering,)
         )
 
-        while range_end - range_start > 1:
-            middle = int((range_end + range_start) / 2)
-            txn.execute(sql, (middle,))
-            middle_ts = txn.fetchone()[0]
-            if ts > middle_ts:
-                range_start = middle
-            else:
-                range_end = middle
+        logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
 
-        return range_end
+        txn.execute(
+            "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
+            (rotate_to_stream_ordering,)
+        )
 
 
 def _action_has_highlight(actions):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c88f689d3a..05cde96afc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,59 +13,61 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore
 
-from twisted.internet import defer, reactor
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+import itertools
+import logging
 
-from synapse.events import FrozenEvent, USE_FROZEN_DICTS
-from synapse.events.utils import prune_event
+import simplejson as json
+from twisted.internet import defer
 
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.util.async import ObservableDeferred
+from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.logcontext import (
-    preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
+    PreserveLoggingContext, make_deferred_yieldable,
 )
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
-from synapse.state import resolve_events
-from synapse.util.caches.descriptors import cached
-
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
-
-import synapse
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.types import get_domain_from_id
 import synapse.metrics
 
-
-import logging
-import math
-import ujson as json
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
 
 logger = logging.getLogger(__name__)
 
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 persist_event_counter = metrics.register_counter("persisted_events")
+event_counter = metrics.register_counter(
+    "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
+)
 
-
-def encode_json(json_object):
-    if USE_FROZEN_DICTS:
-        # ujson doesn't like frozen_dicts
-        return encode_canonical_json(json_object)
-    else:
-        return json.dumps(json_object, ensure_ascii=False)
+# The number of times we are recalculating the current state
+state_delta_counter = metrics.register_counter(
+    "state_delta",
+)
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = metrics.register_counter(
+    "state_delta_single_event",
+)
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = metrics.register_counter(
+    "state_delta_reuse_delta",
+)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+def encode_json(json_object):
+    return frozendict_json_encoder.encode(json_object)
 
 
 class _EventPeristenceQueue(object):
@@ -82,15 +85,30 @@ class _EventPeristenceQueue(object):
 
     def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
+
+        NB: due to the normal usage pattern of this method, it does *not*
+        follow the synapse logcontext rules, and leaves the logcontext in
+        place whether or not the returned deferred is ready.
+
+        Args:
+            room_id (str):
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+
+        Returns:
+            defer.Deferred: a deferred which will resolve once the events are
+                persisted. Runs its callbacks *without* a logcontext.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
+            # if the last item in the queue has the same `backfilled` setting,
+            # we can just add these new events to that item.
             end_item = queue[-1]
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
 
-        deferred = ObservableDeferred(defer.Deferred())
+        deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
 
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
@@ -103,11 +121,11 @@ class _EventPeristenceQueue(object):
     def handle_queue(self, room_id, per_item_callback):
         """Attempts to handle the queue for a room if not already being handled.
 
-        The given callback will be invoked with for each item in the queue,1
+        The given callback will be invoked with for each item in the queue,
         of type _EventPersistQueueItem. The per_item_callback will continuously
         be called with new items, unless the queue becomnes empty. The return
         value of the function will be given to the deferreds waiting on the item,
-        exceptions will be passed to the deferres as well.
+        exceptions will be passed to the deferreds as well.
 
         This function should therefore be called whenever anything is added
         to the queue.
@@ -126,18 +144,25 @@ class _EventPeristenceQueue(object):
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
+                    # handle_queue_loop runs in the sentinel logcontext, so
+                    # there is no need to preserve_fn when running the
+                    # callbacks on the deferred.
                     try:
                         ret = yield per_item_callback(item)
                         item.deferred.callback(ret)
-                    except Exception as e:
-                        item.deferred.errback(e)
+                    except Exception:
+                        item.deferred.errback()
             finally:
                 queue = self._event_persist_queues.pop(room_id, None)
                 if queue:
                     self._event_persist_queues[room_id] = queue
                 self._currently_persisting_rooms.discard(room_id)
 
-        preserve_fn(handle_queue_loop)()
+        # set handle_queue_loop off on the background. We don't want to
+        # attribute work done in it to the current request, so we drop the
+        # logcontext altogether.
+        with PreserveLoggingContext():
+            handle_queue_loop()
 
     def _get_drainining_queue(self, room_id):
         queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -173,13 +198,12 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsStore(SQLBaseStore):
+class EventsStore(EventsWorkerStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
-    def __init__(self, hs):
-        super(EventsStore, self).__init__(hs)
-        self._clock = hs.get_clock()
+    def __init__(self, db_conn, hs):
+        super(EventsStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
@@ -196,8 +220,22 @@ class EventsStore(SQLBaseStore):
             where_clause="contains_url = true AND outlier = false",
         )
 
+        # an event_id index on event_search is useful for the purge_history
+        # api. Plus it means we get to enforce some integrity with a UNIQUE
+        # clause
+        self.register_background_index_update(
+            "event_search_event_id_idx",
+            index_name="event_search_event_id_idx",
+            table="event_search",
+            columns=["event_id"],
+            unique=True,
+            psql_only=True,
+        )
+
         self._event_persist_queue = _EventPeristenceQueue()
 
+        self._state_resolution_handler = hs.get_state_resolution_handler()
+
     def persist_events(self, events_and_contexts, backfilled=False):
         """
         Write events to the database
@@ -210,23 +248,34 @@ class EventsStore(SQLBaseStore):
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
         deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
-            d = preserve_fn(self._event_persist_queue.add_to_queue)(
+        for room_id, evs_ctxs in partitioned.iteritems():
+            d = self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
             )
             deferreds.append(d)
 
-        for room_id in partitioned.keys():
+        for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        return preserve_context_over_deferred(
+        return make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
     @defer.inlineCallbacks
     @log_function
     def persist_event(self, event, context, backfilled=False):
+        """
+
+        Args:
+            event (EventBase):
+            context (EventContext):
+            backfilled (bool):
+
+        Returns:
+            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            and the stream ordering of the latest persisted event
+        """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
@@ -234,7 +283,7 @@ class EventsStore(SQLBaseStore):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield preserve_context_over_deferred(deferred)
+        yield make_deferred_yieldable(deferred)
 
         max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
@@ -242,10 +291,11 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            yield self._persist_events(
-                item.events_and_contexts,
-                backfilled=item.backfilled,
-            )
+            with Measure(self._clock, "persist_events"):
+                yield self._persist_events(
+                    item.events_and_contexts,
+                    backfilled=item.backfilled,
+                )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -253,6 +303,16 @@ class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     def _persist_events(self, events_and_contexts, backfilled=False,
                         delete_existing=False):
+        """Persist events to db
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+            backfilled (bool):
+            delete_existing (bool):
+
+        Returns:
+            Deferred: resolves when the events have been persisted
+        """
         if not events_and_contexts:
             return
 
@@ -282,8 +342,20 @@ class EventsStore(SQLBaseStore):
 
                 # NB: Assumes that we are only persisting events for one room
                 # at a time.
+
+                # map room_id->list[event_ids] giving the new forward
+                # extremities in each room
                 new_forward_extremeties = {}
+
+                # map room_id->(type,state_key)->event_id tracking the full
+                # state in each room after adding these events
                 current_state_for_room = {}
+
+                # map room_id->(to_delete, to_insert) where each entry is
+                # a map (type,key)->event_id giving the state delta in each
+                # room
+                state_delta_for_room = {}
+
                 if not backfilled:
                     with Measure(self._clock, "_calculate_state_and_extrem"):
                         # Work out the new "current state" for each room.
@@ -295,7 +367,7 @@ class EventsStore(SQLBaseStore):
                                 (event, context)
                             )
 
-                        for room_id, ev_ctx_rm in events_by_room.items():
+                        for room_id, ev_ctx_rm in events_by_room.iteritems():
                             # Work out new extremities by recursively adding and removing
                             # the new events.
                             latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -305,17 +377,64 @@ class EventsStore(SQLBaseStore):
                                 room_id, ev_ctx_rm, latest_event_ids
                             )
 
-                            if new_latest_event_ids == set(latest_event_ids):
+                            latest_event_ids = set(latest_event_ids)
+                            if new_latest_event_ids == latest_event_ids:
                                 # No change in extremities, so no change in state
                                 continue
 
                             new_forward_extremeties[room_id] = new_latest_event_ids
 
-                            state = yield self._calculate_state_delta(
-                                room_id, ev_ctx_rm, new_latest_event_ids
+                            len_1 = (
+                                len(latest_event_ids) == 1
+                                and len(new_latest_event_ids) == 1
+                            )
+                            if len_1:
+                                all_single_prev_not_state = all(
+                                    len(event.prev_events) == 1
+                                    and not event.is_state()
+                                    for event, ctx in ev_ctx_rm
+                                )
+                                # Don't bother calculating state if they're just
+                                # a long chain of single ancestor non-state events.
+                                if all_single_prev_not_state:
+                                    continue
+
+                            state_delta_counter.inc()
+                            if len(new_latest_event_ids) == 1:
+                                state_delta_single_event_counter.inc()
+
+                                # This is a fairly handwavey check to see if we could
+                                # have guessed what the delta would have been when
+                                # processing one of these events.
+                                # What we're interested in is if the latest extremities
+                                # were the same when we created the event as they are
+                                # now. When this server creates a new event (as opposed
+                                # to receiving it over federation) it will use the
+                                # forward extremities as the prev_events, so we can
+                                # guess this by looking at the prev_events and checking
+                                # if they match the current forward extremities.
+                                for ev, _ in ev_ctx_rm:
+                                    prev_event_ids = set(e for e, _ in ev.prev_events)
+                                    if latest_event_ids == prev_event_ids:
+                                        state_delta_reuse_delta_counter.inc()
+                                        break
+
+                            logger.info(
+                                "Calculating state delta for room %s", room_id,
                             )
-                            if state:
-                                current_state_for_room[room_id] = state
+                            current_state = yield self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
+                            )
+                            if current_state is not None:
+                                current_state_for_room[room_id] = current_state
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state,
+                                )
+                                if delta is not None:
+                                    state_delta_for_room[room_id] = delta
 
                 yield self.runInteraction(
                     "persist_events",
@@ -323,10 +442,35 @@ class EventsStore(SQLBaseStore):
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
-                    current_state_for_room=current_state_for_room,
+                    state_delta_for_room=state_delta_for_room,
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
+                synapse.metrics.event_persisted_position.set(
+                    chunk[-1][0].internal_metadata.stream_ordering,
+                )
+                for event, context in chunk:
+                    if context.app_service:
+                        origin_type = "local"
+                        origin_entity = context.app_service.id
+                    elif self.hs.is_mine_id(event.sender):
+                        origin_type = "local"
+                        origin_entity = "*client*"
+                    else:
+                        origin_type = "remote"
+                        origin_entity = get_domain_from_id(event.sender)
+
+                    event_counter.inc(event.type, origin_type, origin_entity)
+
+                for room_id, new_state in current_state_for_room.iteritems():
+                    self.get_current_state_ids.prefill(
+                        (room_id, ), new_state
+                    )
+
+                for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+                    self.get_latest_event_ids_in_room.prefill(
+                        (room_id,), list(latest_event_ids)
+                    )
 
     @defer.inlineCallbacks
     def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
@@ -370,71 +514,137 @@ class EventsStore(SQLBaseStore):
         defer.returnValue(new_latest_event_ids)
 
     @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
-        """Calculate the new state deltas for a room.
+    def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+                                    new_latest_event_ids):
+        """Calculate the current state dict after adding some new events to
+        a room
 
-        Assumes that we are only persisting events for one room at a time.
+        Args:
+            room_id (str):
+                room to which the events are being added. Used for logging etc
+
+            events_context (list[(EventBase, EventContext)]):
+                events and contexts which are being added to the room
+
+            old_latest_event_ids (iterable[str]):
+                the old forward extremities for the room.
+
+            new_latest_event_ids (iterable[str]):
+                the new forward extremities for the room.
 
         Returns:
-            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
-            (type, state_key) -> event_id. `to_delete` are the entries to
-            first be deleted from current_state_events, `to_insert` are entries
-            to insert.
-            May return None if there are no changes to be applied.
+            Deferred[dict[(str,str), str]|None]:
+                None if there are no changes to the room state, or
+                a dict of (type, state_key) -> event_id].
         """
-        # Now we need to work out the different state sets for
-        # each state extremities
-        state_sets = []
-        missing_event_ids = []
-        was_updated = False
+
+        if not new_latest_event_ids:
+            return
+
+        # map from state_group to ((type, key) -> event_id) state map
+        state_groups_map = {}
+        for ev, ctx in events_context:
+            if ctx.state_group is None:
+                # I don't think this can happen, but let's double-check
+                raise Exception(
+                    "Context for new extremity event %s has no state "
+                    "group" % (ev.event_id, ),
+                )
+
+            if ctx.state_group in state_groups_map:
+                continue
+
+            state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+        # We need to map the event_ids to their state groups. First, let's
+        # check if the event is one we're persisting, in which case we can
+        # pull the state group from its context.
+        # Otherwise we need to pull the state group from the database.
+
+        # Set of events we need to fetch groups for. (We know none of the old
+        # extremities are going to be in events_context).
+        missing_event_ids = set(old_latest_event_ids)
+
+        event_id_to_state_group = {}
         for event_id in new_latest_event_ids:
-            # First search in the list of new events we're adding,
-            # and then use the current state from that
+            # First search in the list of new events we're adding.
             for ev, ctx in events_context:
                 if event_id == ev.event_id:
-                    if ctx.current_state_ids is None:
-                        raise Exception("Unknown current state")
-                    state_sets.append(ctx.current_state_ids)
-                    if ctx.delta_ids or hasattr(ev, "state_key"):
-                        was_updated = True
+                    event_id_to_state_group[event_id] = ctx.state_group
                     break
             else:
                 # If we couldn't find it, then we'll need to pull
                 # the state from the database
-                was_updated = True
-                missing_event_ids.append(event_id)
+                missing_event_ids.add(event_id)
 
         if missing_event_ids:
-            # Now pull out the state for any missing events from DB
+            # Now pull out the state groups for any missing events from DB
             event_to_groups = yield self._get_state_group_for_events(
                 missing_event_ids,
             )
+            event_id_to_state_group.update(event_to_groups)
 
-            groups = set(event_to_groups.values())
-            group_to_state = yield self._get_state_for_groups(groups)
+        # State groups of old_latest_event_ids
+        old_state_groups = set(
+            event_id_to_state_group[evid] for evid in old_latest_event_ids
+        )
 
-            state_sets.extend(group_to_state.values())
+        # State groups of new_latest_event_ids
+        new_state_groups = set(
+            event_id_to_state_group[evid] for evid in new_latest_event_ids
+        )
 
-        if not new_latest_event_ids:
-            current_state = {}
-        elif was_updated:
-            current_state = yield resolve_events(
-                state_sets,
-                state_map_factory=lambda ev_ids: self.get_events(
-                    ev_ids, get_prev_content=False, check_redacted=False,
-                ),
-            )
-        else:
+        # If they old and new groups are the same then we don't need to do
+        # anything.
+        if old_state_groups == new_state_groups:
             return
 
-        existing_state_rows = yield self._simple_select_list(
-            table="current_state_events",
-            keyvalues={"room_id": room_id},
-            retcols=["event_id", "type", "state_key"],
-            desc="_calculate_state_delta",
+        # Now that we have calculated new_state_groups we need to get
+        # their state IDs so we can resolve to a single state set.
+        missing_state = new_state_groups - set(state_groups_map)
+        if missing_state:
+            group_to_state = yield self._get_state_for_groups(missing_state)
+            state_groups_map.update(group_to_state)
+
+        if len(new_state_groups) == 1:
+            # If there is only one state group, then we know what the current
+            # state is.
+            defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+        # Ok, we need to defer to the state handler to resolve our state sets.
+
+        def get_events(ev_ids):
+            return self.get_events(
+                ev_ids, get_prev_content=False, check_redacted=False,
+            )
+
+        state_groups = {
+            sg: state_groups_map[sg] for sg in new_state_groups
+        }
+
+        events_map = {ev.event_id: ev for ev, _ in events_context}
+        logger.debug("calling resolve_state_groups from preserve_events")
+        res = yield self._state_resolution_handler.resolve_state_groups(
+            room_id, state_groups, events_map, get_events
         )
 
-        existing_events = set(row["event_id"] for row in existing_state_rows)
+        defer.returnValue(res.state)
+
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, current_state):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts,
+            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+            first be deleted from current_state_events, `to_insert` are entries
+            to insert.
+        """
+        existing_state = yield self.get_current_state_ids(room_id)
+
+        existing_events = set(existing_state.itervalues())
         new_events = set(ev_id for ev_id in current_state.itervalues())
         changed_events = existing_events ^ new_events
 
@@ -442,9 +652,8 @@ class EventsStore(SQLBaseStore):
             return
 
         to_delete = {
-            (row["type"], row["state_key"]): row["event_id"]
-            for row in existing_state_rows
-            if row["event_id"] in changed_events
+            key: ev_id for key, ev_id in existing_state.iteritems()
+            if ev_id in changed_events
         }
         events_to_insert = (new_events - existing_events)
         to_insert = {
@@ -454,77 +663,104 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue((to_delete, to_insert))
 
-    @defer.inlineCallbacks
-    def get_event(self, event_id, check_redacted=True,
-                  get_prev_content=False, allow_rejected=False,
-                  allow_none=False):
-        """Get an event from the database by event_id.
+    @log_function
+    def _persist_events_txn(self, txn, events_and_contexts, backfilled,
+                            delete_existing=False, state_delta_for_room={},
+                            new_forward_extremeties={}):
+        """Insert some number of room events into the necessary database tables.
+
+        Rejected events are only inserted into the events table, the events_json table,
+        and the rejections table. Things reading from those table will need to check
+        whether the event was rejected.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
-                False throw an exception.
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]):
+                events to persist
+            backfilled (bool): True if the events were backfilled
+            delete_existing (bool): True to purge existing table rows for the
+                events from the database. This is useful when retrying due to
+                IntegrityError.
+            state_delta_for_room (dict[str, (list[str], list[str])]):
+                The current-state delta for each room. For each room, a tuple
+                (to_delete, to_insert), being a list of event ids to be removed
+                from the current state, and a list of event ids to be added to
+                the current state.
+            new_forward_extremeties (dict[str, list[str]]):
+                The new forward extremities for each room. For each room, a
+                list of the event ids which are the forward extremities.
 
-        Returns:
-            Deferred : A FrozenEvent.
         """
-        events = yield self._get_events(
-            [event_id],
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
+        all_events_and_contexts = events_and_contexts
+
+        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+
+        self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
+
+        self._update_forward_extremities_txn(
+            txn,
+            new_forward_extremities=new_forward_extremeties,
+            max_stream_order=max_stream_order,
         )
 
-        if not events and not allow_none:
-            raise SynapseError(404, "Could not find event %s" % (event_id,))
+        # Ensure that we don't have the same event twice.
+        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
+            events_and_contexts,
+        )
 
-        defer.returnValue(events[0] if events else None)
+        self._update_room_depths_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+        )
 
-    @defer.inlineCallbacks
-    def get_events(self, event_ids, check_redacted=True,
-                   get_prev_content=False, allow_rejected=False):
-        """Get events from the database
+        # _update_outliers_txn filters out any events which have already been
+        # persisted, and returns the filtered list.
+        events_and_contexts = self._update_outliers_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
 
-        Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
-                include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+        # From this point onwards the events are only events that we haven't
+        # seen before.
 
-        Returns:
-            Deferred : Dict from event_id to event.
-        """
-        events = yield self._get_events(
-            event_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
+        if delete_existing:
+            # For paranoia reasons, we go and delete all the existing entries
+            # for these events so we can reinsert them.
+            # This gets around any problems with some tables already having
+            # entries.
+            self._delete_existing_rows_txn(
+                txn,
+                events_and_contexts=events_and_contexts,
+            )
+
+        self._store_event_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
         )
 
-        defer.returnValue({e.event_id: e for e in events})
+        # Insert into event_to_state_groups.
+        self._store_event_state_mappings_txn(txn, events_and_contexts)
 
-    @log_function
-    def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False, current_state_for_room={},
-                            new_forward_extremeties={}):
-        """Insert some number of room events into the necessary database tables.
+        # _store_rejected_events_txn filters out any events which were
+        # rejected, and returns the filtered list.
+        events_and_contexts = self._store_rejected_events_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+        )
 
-        Rejected events are only inserted into the events table, the events_json table,
-        and the rejections table. Things reading from those table will need to check
-        whether the event was rejected.
+        # From this point onwards the events are only ones that weren't
+        # rejected.
 
-        If delete_existing is True then existing events will be purged from the
-        database before insertion. This is useful when retrying due to IntegrityError.
-        """
-        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
-        for room_id, current_state_tuple in current_state_for_room.iteritems():
+        self._update_metadata_tables_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+            backfilled=backfilled,
+        )
+
+    def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
+        for room_id, current_state_tuple in state_delta_by_room.iteritems():
                 to_delete, to_insert = current_state_tuple
                 txn.executemany(
                     "DELETE FROM current_state_events WHERE event_id = ?",
@@ -545,6 +781,29 @@ class EventsStore(SQLBaseStore):
                     ],
                 )
 
+                state_deltas = {key: None for key in to_delete}
+                state_deltas.update(to_insert)
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="current_state_delta_stream",
+                    values=[
+                        {
+                            "stream_id": max_stream_order,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": ev_id,
+                            "prev_event_id": to_delete.get(key, None),
+                        }
+                        for key, ev_id in state_deltas.iteritems()
+                    ]
+                )
+
+                self._curr_state_delta_stream_cache.entity_has_changed(
+                    room_id, max_stream_order,
+                )
+
                 # Invalidate the various caches
 
                 # Figure out the changes of membership to invalidate the
@@ -553,24 +812,34 @@ class EventsStore(SQLBaseStore):
                 # and which we have added, then we invlidate the caches for all
                 # those users.
                 members_changed = set(
-                    state_key for ev_type, state_key in to_delete.iterkeys()
-                    if ev_type == EventTypes.Member
-                )
-                members_changed.update(
-                    state_key for ev_type, state_key in to_insert.iterkeys()
+                    state_key for ev_type, state_key in state_deltas
                     if ev_type == EventTypes.Member
                 )
 
                 for member in members_changed:
                     self._invalidate_cache_and_stream(
-                        txn, self.get_rooms_for_user, (member,)
+                        txn, self.get_rooms_for_user_with_stream_ordering, (member,)
+                    )
+
+                for host in set(get_domain_from_id(u) for u in members_changed):
+                    self._invalidate_cache_and_stream(
+                        txn, self.is_host_joined, (room_id, host)
+                    )
+                    self._invalidate_cache_and_stream(
+                        txn, self.was_host_joined, (room_id, host)
                     )
 
                 self._invalidate_cache_and_stream(
                     txn, self.get_users_in_room, (room_id,)
                 )
 
-        for room_id, new_extrem in new_forward_extremeties.items():
+                self._invalidate_cache_and_stream(
+                    txn, self.get_current_state_ids, (room_id,)
+                )
+
+    def _update_forward_extremities_txn(self, txn, new_forward_extremities,
+                                        max_stream_order):
+        for room_id, new_extrem in new_forward_extremities.iteritems():
             self._simple_delete_txn(
                 txn,
                 table="event_forward_extremities",
@@ -588,7 +857,7 @@ class EventsStore(SQLBaseStore):
                     "event_id": ev_id,
                     "room_id": room_id,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for ev_id in new_extrem
             ],
         )
@@ -605,13 +874,22 @@ class EventsStore(SQLBaseStore):
                     "event_id": event_id,
                     "stream_ordering": max_stream_order,
                 }
-                for room_id, new_extrem in new_forward_extremeties.items()
+                for room_id, new_extrem in new_forward_extremities.iteritems()
                 for event_id in new_extrem
             ]
         )
 
-        # Ensure that we don't have the same event twice.
-        # Pick the earliest non-outlier if there is one, else the earliest one.
+    @classmethod
+    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+        """Ensure that we don't have the same event twice.
+
+        Pick the earliest non-outlier if there is one, else the earliest one.
+
+        Args:
+            events_and_contexts (list[(EventBase, EventContext)]):
+        Returns:
+            list[(EventBase, EventContext)]: filtered list
+        """
         new_events_and_contexts = OrderedDict()
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
@@ -624,9 +902,17 @@ class EventsStore(SQLBaseStore):
                         new_events_and_contexts[event.event_id] = (event, context)
             else:
                 new_events_and_contexts[event.event_id] = (event, context)
+        return new_events_and_contexts.values()
 
-        events_and_contexts = new_events_and_contexts.values()
+    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+        """Update min_depth for each room
 
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            backfilled (bool): True if the events were backfilled
+        """
         depth_updates = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
@@ -642,9 +928,24 @@ class EventsStore(SQLBaseStore):
                     event.depth, depth_updates.get(event.room_id, event.depth)
                 )
 
-        for room_id, depth in depth_updates.items():
+        for room_id, depth in depth_updates.iteritems():
             self._update_min_depth_for_room_txn(txn, room_id, depth)
 
+    def _update_outliers_txn(self, txn, events_and_contexts):
+        """Update any outliers with new event info.
+
+        This turns outliers into ex-outliers (unless the new event was
+        rejected).
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without events which
+            are already in the events table.
+        """
         txn.execute(
             "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
                 ",".join(["?"] * len(events_and_contexts)),
@@ -654,34 +955,30 @@ class EventsStore(SQLBaseStore):
 
         have_persisted = {
             event_id: outlier
-            for event_id, outlier in txn.fetchall()
+            for event_id, outlier in txn
         }
 
         to_remove = set()
         for event, context in events_and_contexts:
-            if context.rejected:
-                # If the event is rejected then we don't care if the event
-                # was an outlier or not.
-                if event.event_id in have_persisted:
-                    # If we have already seen the event then ignore it.
-                    to_remove.add(event)
-                continue
-
             if event.event_id not in have_persisted:
                 continue
 
             to_remove.add(event)
 
+            if context.rejected:
+                # If the event is rejected then we don't care if the event
+                # was an outlier or not.
+                continue
+
             outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
                 # an outlier in the database. We now have some state at that
                 # so we need to update the state_groups table with that state.
 
-                # insert into the state_group, state_groups_state and
-                # event_to_state_groups tables.
+                # insert into event_to_state_groups.
                 try:
-                    self._store_mult_state_groups_txn(txn, ((event, context),))
+                    self._store_event_state_mappings_txn(txn, ((event, context),))
                 except Exception:
                     logger.exception("")
                     raise
@@ -726,37 +1023,19 @@ class EventsStore(SQLBaseStore):
                 # event isn't an outlier any more.
                 self._update_backward_extremeties(txn, [event])
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    @classmethod
+    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only events that we haven't
-        # seen before.
-
-        def event_dict(event):
-            return {
-                k: v
-                for k, v in event.get_dict().items()
-                if k not in [
-                    "redacted",
-                    "redacted_because",
-                ]
-            }
-
-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-
-            logger.info("Deleting existing")
+        logger.info("Deleting existing")
 
-            for table in (
+        for table in (
                 "events",
                 "event_auth",
                 "event_json",
@@ -779,11 +1058,30 @@ class EventsStore(SQLBaseStore):
                 "redactions",
                 "room_memberships",
                 "topics"
-            ):
-                txn.executemany(
-                    "DELETE FROM %s WHERE event_id = ?" % (table,),
-                    [(ev.event_id,) for ev, _ in events_and_contexts]
-                )
+        ):
+            txn.executemany(
+                "DELETE FROM %s WHERE event_id = ?" % (table,),
+                [(ev.event_id,) for ev, _ in events_and_contexts]
+            )
+
+    def _store_event_txn(self, txn, events_and_contexts):
+        """Insert new events into the event and event_json tables
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+        """
+
+        if not events_and_contexts:
+            # nothing to do here
+            return
+
+        def event_dict(event):
+            d = event.get_dict()
+            d.pop("redacted", None)
+            d.pop("redacted_because", None)
+            return d
 
         self._simple_insert_many_txn(
             txn,
@@ -827,6 +1125,19 @@ class EventsStore(SQLBaseStore):
             ],
         )
 
+    def _store_rejected_events_txn(self, txn, events_and_contexts):
+        """Add rows to the 'rejections' table for received events which were
+        rejected
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+
+        Returns:
+            list[(EventBase, EventContext)] new list, without the rejected
+                events.
+        """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
         to_remove = set()
@@ -838,24 +1149,37 @@ class EventsStore(SQLBaseStore):
                 )
                 to_remove.add(event)
 
-        events_and_contexts = [
+        return [
             ec for ec in events_and_contexts if ec[0] not in to_remove
         ]
 
+    def _update_metadata_tables_txn(self, txn, events_and_contexts,
+                                    all_events_and_contexts, backfilled):
+        """Update all the miscellaneous tables for new events
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            events_and_contexts (list[(EventBase, EventContext)]): events
+                we are persisting
+            all_events_and_contexts (list[(EventBase, EventContext)]): all
+                events that we were going to persist. This includes events
+                we've already persisted, etc, that wouldn't appear in
+                events_and_context.
+            backfilled (bool): True if the events were backfilled
+        """
+
+        # Insert all the push actions into the event_push_actions table.
+        self._set_push_actions_for_event_and_users_txn(
+            txn,
+            events_and_contexts=events_and_contexts,
+            all_events_and_contexts=all_events_and_contexts,
+        )
+
         if not events_and_contexts:
-            # Make sure we don't pass an empty list to functions that expect to
-            # be storing at least one element.
+            # nothing to do here
             return
 
-        # From this point onwards the events are only ones that weren't rejected.
-
         for event, context in events_and_contexts:
-            # Insert all the push actions into the event_push_actions table.
-            if context.push_actions:
-                self._set_push_actions_for_event_and_users_txn(
-                    txn, event, context.push_actions
-                )
-
             if event.type == EventTypes.Redaction and event.redacts is not None:
                 # Remove the entries in the event_push_actions table for the
                 # redacted event.
@@ -874,13 +1198,10 @@ class EventsStore(SQLBaseStore):
                 }
                 for event, _ in events_and_contexts
                 for auth_id, _ in event.auth_events
+                if event.is_state()
             ],
         )
 
-        # Insert into the state_groups, state_groups_state, and
-        # event_to_state_groups tables.
-        self._store_mult_state_groups_txn(txn, events_and_contexts)
-
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
@@ -967,13 +1288,6 @@ class EventsStore(SQLBaseStore):
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
 
-        if backfilled:
-            # Backfilled events come before the current state so we don't need
-            # to update the current state table
-            return
-
-        return
-
     def _add_to_cache(self, txn, events_and_contexts):
         to_prefill = []
 
@@ -1037,13 +1351,49 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue(set(r["event_id"] for r in rows))
 
-    def have_events(self, event_ids):
+    @defer.inlineCallbacks
+    def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
+        Args:
+            event_ids (iterable[str]):
+
         Returns:
-            dict: Has an entry for each event id we already have seen. Maps to
-            the rejected reason string if we rejected the event, else maps to
-            None.
+            Deferred[set[str]]: The events we have already seen.
+        """
+        results = set()
+
+        def have_seen_events_txn(txn, chunk):
+            sql = (
+                "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+                % (",".join("?" * len(chunk)), )
+            )
+            txn.execute(sql, chunk)
+            for (event_id, ) in txn:
+                results.add(event_id)
+
+        # break the input up into chunks of 100
+        input_iterator = iter(event_ids)
+        for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+                          []):
+            yield self.runInteraction(
+                "have_seen_events",
+                have_seen_events_txn,
+                chunk,
+            )
+        defer.returnValue(results)
+
+    def get_seen_events_with_rejections(self, event_ids):
+        """Given a list of event ids, check if we rejected them.
+
+        Args:
+            event_ids (list[str])
+
+        Returns:
+            Deferred[dict[str, str|None):
+                Has an entry for each event id we already have seen. Maps to
+                the rejected reason string if we rejected the event, else maps
+                to None.
         """
         if not event_ids:
             return defer.succeed({})
@@ -1065,280 +1415,7 @@ class EventsStore(SQLBaseStore):
 
             return res
 
-        return self.runInteraction(
-            "have_events", f,
-        )
-
-    @defer.inlineCallbacks
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False, allow_rejected=False):
-        if not event_ids:
-            defer.returnValue([])
-
-        event_id_list = event_ids
-        event_ids = set(event_ids)
-
-        event_entry_map = self._get_events_from_cache(
-            event_ids,
-            allow_rejected=allow_rejected,
-        )
-
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
-        if missing_events_ids:
-            missing_events = yield self._enqueue_events(
-                missing_events_ids,
-                check_redacted=check_redacted,
-                allow_rejected=allow_rejected,
-            )
-
-            event_entry_map.update(missing_events)
-
-        events = []
-        for event_id in event_id_list:
-            entry = event_entry_map.get(event_id, None)
-            if not entry:
-                continue
-
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
-
-        defer.returnValue(events)
-
-    def _invalidate_get_event_cache(self, event_id):
-            self._get_event_cache.invalidate((event_id,))
-
-    def _get_events_from_cache(self, events, allow_rejected):
-        event_map = {}
-
-        for event_id in events:
-            ret = self._get_event_cache.get((event_id,), None)
-            if not ret:
-                continue
-
-            if allow_rejected or not ret.event.rejected_reason:
-                event_map[event_id] = ret
-            else:
-                event_map[event_id] = None
-
-        return event_map
-
-    def _do_fetch(self, conn):
-        """Takes a database connection and waits for requests for events from
-        the _event_fetch_list queue.
-        """
-        event_list = []
-        i = 0
-        while True:
-            try:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
-                            self._event_fetch_ongoing -= 1
-                            return
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
-
-                event_id_lists = zip(*event_list)[0]
-                event_ids = [
-                    item for sublist in event_id_lists for item in sublist
-                ]
-
-                rows = self._new_transaction(
-                    conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
-                )
-
-                row_dict = {
-                    r["event_id"]: r
-                    for r in rows
-                }
-
-                # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([
-                                        res[i]
-                                        for i in ids
-                                        if i in res
-                                    ])
-                            except:
-                                logger.exception("Failed to callback")
-                with PreserveLoggingContext():
-                    reactor.callFromThread(fire, event_list, row_dict)
-            except Exception as e:
-                logger.exception("do_fetch")
-
-                # We only want to resolve deferreds from the main thread
-                def fire(evs):
-                    for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(e)
-
-                if event_list:
-                    with PreserveLoggingContext():
-                        reactor.callFromThread(fire, event_list)
-
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
-        """Fetches events from the database using the _event_fetch_list. This
-        allows batch and bulk fetching of events - it allows us to fetch events
-        without having to create a new transaction for each request for events.
-        """
-        if not events:
-            defer.returnValue({})
-
-        events_d = defer.Deferred()
-        with self._event_fetch_lock:
-            self._event_fetch_list.append(
-                (events, events_d)
-            )
-
-            self._event_fetch_lock.notify()
-
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            with PreserveLoggingContext():
-                self.runWithConnection(
-                    self._do_fetch
-                )
-
-        logger.debug("Loading %d events", len(events))
-        with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if not r["rejects"]]
-
-        res = yield preserve_context_over_deferred(defer.gatherResults(
-            [
-                preserve_fn(self._get_event_from_row)(
-                    row["internal_metadata"], row["json"], row["redacts"],
-                    rejected_reason=row["rejects"],
-                )
-                for row in rows
-            ],
-            consumeErrors=True
-        ))
-
-        defer.returnValue({
-            e.event.event_id: e
-            for e in res if e
-        })
-
-    def _fetch_event_rows(self, txn, events):
-        rows = []
-        N = 200
-        for i in range(1 + len(events) / N):
-            evs = events[i * N:(i + 1) * N]
-            if not evs:
-                break
-
-            sql = (
-                "SELECT "
-                " e.event_id as event_id, "
-                " e.internal_metadata,"
-                " e.json,"
-                " r.redacts as redacts,"
-                " rej.event_id as rejects "
-                " FROM event_json as e"
-                " LEFT JOIN rejections as rej USING (event_id)"
-                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
-                " WHERE e.event_id IN (%s)"
-            ) % (",".join(["?"] * len(evs)),)
-
-            txn.execute(sql, evs)
-            rows.extend(self.cursor_to_dict(txn))
-
-        return rows
-
-    @defer.inlineCallbacks
-    def _get_event_from_row(self, internal_metadata, js, redacted,
-                            rejected_reason=None):
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if rejected_reason:
-                rejected_reason = yield self._simple_select_one_onecol(
-                    table="rejections",
-                    keyvalues={"event_id": rejected_reason},
-                    retcol="reason",
-                    desc="_get_event_from_row_rejected_reason",
-                )
-
-            original_ev = FrozenEvent(
-                d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = None
-            if redacted:
-                redacted_event = prune_event(original_ev)
-
-                redaction_id = yield self._simple_select_one_onecol(
-                    table="redactions",
-                    keyvalues={"redacts": redacted_event.event_id},
-                    retcol="event_id",
-                    desc="_get_event_from_row_redactions",
-                )
-
-                redacted_event.unsigned["redacted_by"] = redaction_id
-                # Get the redaction event.
-
-                because = yield self.get_event(
-                    redaction_id,
-                    check_redacted=False,
-                    allow_none=True,
-                )
-
-                if because:
-                    # It's fine to do add the event directly, since get_pdu_json
-                    # will serialise this field correctly
-                    redacted_event.unsigned["redacted_because"] = because
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev,
-                redacted_event=redacted_event,
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        defer.returnValue(cache_entry)
+        return self.runInteraction("get_rejection_reasons", f)
 
     @defer.inlineCallbacks
     def count_daily_messages(self):
@@ -1349,66 +1426,52 @@ class EventsStore(SQLBaseStore):
         call to this function, it will return None.
         """
         def _count_messages(txn):
-            now = self.hs.get_clock().time()
-
-            txn.execute(
-                "SELECT reported_stream_token, reported_time FROM stats_reporting"
-            )
-            last_reported = self.cursor_to_dict(txn)
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            count, = txn.fetchone()
+            return count
 
-            txn.execute(
-                "SELECT stream_ordering"
-                " FROM events"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT 1"
-            )
-            now_reporting = self.cursor_to_dict(txn)
-            if not now_reporting:
-                logger.info("Calculating daily messages skipped; no now_reporting")
-                return None
-            now_reporting = now_reporting[0]["stream_ordering"]
-
-            txn.execute("DELETE FROM stats_reporting")
-            txn.execute(
-                "INSERT INTO stats_reporting"
-                " (reported_stream_token, reported_time)"
-                " VALUES (?, ?)",
-                (now_reporting, now,)
-            )
-
-            if not last_reported:
-                logger.info("Calculating daily messages skipped; no last_reported")
-                return None
-
-            # Close enough to correct for our purposes.
-            yesterday = (now - 24 * 60 * 60)
-            since_yesterday_seconds = yesterday - last_reported[0]["reported_time"]
-            any_since_yesterday = math.fabs(since_yesterday_seconds) > 60 * 60
-            if any_since_yesterday:
-                logger.info(
-                    "Calculating daily messages skipped; since_yesterday_seconds: %d" %
-                    (since_yesterday_seconds,)
-                )
-                return None
+        ret = yield self.runInteraction("count_messages", _count_messages)
+        defer.returnValue(ret)
 
-            txn.execute(
-                "SELECT COUNT(*) as messages"
-                " FROM events NATURAL JOIN event_json"
-                " WHERE json like '%m.room.message%'"
-                " AND stream_ordering > ?"
-                " AND stream_ordering <= ?",
-                (
-                    last_reported[0]["reported_stream_token"],
-                    now_reporting,
-                )
-            )
-            rows = self.cursor_to_dict(txn)
-            if not rows:
-                logger.info("Calculating daily messages skipped; messages count missing")
-                return None
-            return rows[0]["messages"]
+    @defer.inlineCallbacks
+    def count_daily_sent_messages(self):
+        def _count_messages(txn):
+            # This is good enough as if you have silly characters in your own
+            # hostname then thats your own fault.
+            like_clause = "%:" + self.hs.hostname
+
+            sql = """
+                SELECT COALESCE(COUNT(*), 0) FROM events
+                WHERE type = 'm.room.message'
+                    AND sender LIKE ?
+                AND stream_ordering > ?
+            """
+
+            txn.execute(sql, (like_clause, self.stream_ordering_day_ago,))
+            count, = txn.fetchone()
+            return count
+
+        ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
+        defer.returnValue(ret)
 
-        ret = yield self.runInteraction("count_messages", _count_messages)
+    @defer.inlineCallbacks
+    def count_daily_active_rooms(self):
+        def _count(txn):
+            sql = """
+                SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+                WHERE type = 'm.room.message'
+                AND stream_ordering > ?
+            """
+            txn.execute(sql, (self.stream_ordering_day_ago,))
+            count, = txn.fetchone()
+            return count
+
+        ret = yield self.runInteraction("count_daily_active_rooms", _count)
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
@@ -1569,6 +1632,94 @@ class EventsStore(SQLBaseStore):
         """The current minimum token that backfilled events have reached"""
         return -self._backfill_id_gen.get_current_token()
 
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+        return self.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+        return self.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
+
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
@@ -1582,14 +1733,13 @@ class EventsStore(SQLBaseStore):
 
         def get_all_new_events_txn(txn):
             sql = (
-                "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
             if have_forward_events:
@@ -1615,15 +1765,13 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = []
 
             sql = (
-                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
-                " eg.state_group"
-                " FROM events as e"
-                " JOIN event_json as ej"
-                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
-                " LEFT JOIN event_to_state_groups as eg"
-                " ON e.event_id = eg.event_id"
-                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
-                " ORDER BY e.stream_ordering DESC"
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT ?"
             )
             if have_backfill_events:
@@ -1654,16 +1802,32 @@ class EventsStore(SQLBaseStore):
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
-    def delete_old_state(self, room_id, topological_ordering):
-        return self.runInteraction(
-            "delete_old_state",
-            self._delete_old_state_txn, room_id, topological_ordering
-        )
+    def purge_history(
+        self, room_id, topological_ordering, delete_local_events,
+    ):
+        """Deletes room history before a certain point
+
+        Args:
+            room_id (str):
 
-    def _delete_old_state_txn(self, txn, room_id, topological_ordering):
-        """Deletes old room state
+            topological_ordering (int):
+                minimum topo ordering to preserve
+
+            delete_local_events (bool):
+                if True, we will delete local events as well as remote ones
+                (instead of just marking them as outliers and deleting their
+                state groups).
         """
 
+        return self.runInteraction(
+            "purge_history",
+            self._purge_history_txn, room_id, topological_ordering,
+            delete_local_events,
+        )
+
+    def _purge_history_txn(
+        self, txn, room_id, topological_ordering, delete_local_events,
+    ):
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -1684,6 +1848,30 @@ class EventsStore(SQLBaseStore):
         #     state_groups
         #     state_groups_state
 
+        # we will build a temporary table listing the events so that we don't
+        # have to keep shovelling the list back and forth across the
+        # connection. Annoyingly the python sqlite driver commits the
+        # transaction on CREATE, so let's do this first.
+        #
+        # furthermore, we might already have the table from a previous (failed)
+        # purge attempt, so let's drop the table first.
+
+        txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+        txn.execute(
+            "CREATE TEMPORARY TABLE events_to_purge ("
+            "    event_id TEXT NOT NULL,"
+            "    should_delete BOOLEAN NOT NULL"
+            ")"
+        )
+
+        # create an index on should_delete because later we'll be looking for
+        # the should_delete / shouldn't_delete subsets
+        txn.execute(
+            "CREATE INDEX events_to_purge_should_delete"
+            " ON events_to_purge(should_delete)",
+        )
+
         # First ensure that we're not about to delete all the forward extremeties
         txn.execute(
             "SELECT e.event_id, e.depth FROM events as e "
@@ -1704,29 +1892,49 @@ class EventsStore(SQLBaseStore):
                 400, "topological_ordering is greater than forward extremeties"
             )
 
+        logger.info("[purge] looking for events to delete")
+
+        should_delete_expr = "state_key IS NULL"
+        should_delete_params = ()
+        if not delete_local_events:
+            should_delete_expr += " AND event_id NOT LIKE ?"
+            should_delete_params += ("%:" + self.hs.hostname, )
+
+        should_delete_params += (room_id, topological_ordering)
+
+        txn.execute(
+            "INSERT INTO events_to_purge"
+            " SELECT event_id, %s"
+            " FROM events AS e LEFT JOIN state_events USING (event_id)"
+            " WHERE e.room_id = ? AND topological_ordering < ?" % (
+                should_delete_expr,
+            ),
+            should_delete_params,
+        )
         txn.execute(
-            "SELECT event_id, state_key FROM events"
-            " LEFT JOIN state_events USING (room_id, event_id)"
-            " WHERE room_id = ? AND topological_ordering < ?",
-            (room_id, topological_ordering,)
+            "SELECT event_id, should_delete FROM events_to_purge"
         )
         event_rows = txn.fetchall()
+        logger.info(
+            "[purge] found %i events before cutoff, of which %i can be deleted",
+            len(event_rows), sum(1 for e in event_rows if e[1]),
+        )
 
-        for event_id, state_key in event_rows:
-            txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+        logger.info("[purge] Finding new backward extremities")
 
         # We calculate the new entries for the backward extremeties by finding
         # all events that point to events that are to be purged
         txn.execute(
-            "SELECT DISTINCT e.event_id FROM events as e"
-            " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
-            " INNER JOIN events as e2 ON e2.event_id = ed.event_id"
-            " WHERE e.room_id = ? AND e.topological_ordering < ?"
-            " AND e2.topological_ordering >= ?",
-            (room_id, topological_ordering, topological_ordering)
+            "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+            " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+            " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
+            " WHERE e2.topological_ordering >= ?",
+            (topological_ordering, )
         )
         new_backwards_extrems = txn.fetchall()
 
+        logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
         txn.execute(
             "DELETE FROM event_backward_extremities WHERE room_id = ?",
             (room_id,)
@@ -1741,30 +1949,36 @@ class EventsStore(SQLBaseStore):
             ]
         )
 
+        logger.info("[purge] finding redundant state groups")
+
         # Get all state groups that are only referenced by events that are
         # to be deleted.
         txn.execute(
             "SELECT state_group FROM event_to_state_groups"
             " INNER JOIN events USING (event_id)"
             " WHERE state_group IN ("
-            "   SELECT DISTINCT state_group FROM events"
+            "   SELECT DISTINCT state_group FROM events_to_purge"
             "   INNER JOIN event_to_state_groups USING (event_id)"
-            "   WHERE room_id = ? AND topological_ordering < ?"
             " )"
             " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
-            (room_id, topological_ordering, topological_ordering)
+            (topological_ordering, )
         )
 
         state_rows = txn.fetchall()
-        state_groups_to_delete = [sg for sg, in state_rows]
+        logger.info("[purge] found %i redundant state groups", len(state_rows))
+
+        # make a set of the redundant state groups, so that we can look them up
+        # efficiently
+        state_groups_to_delete = set([sg for sg, in state_rows])
 
         # Now we get all the state groups that rely on these state groups
-        new_state_edges = []
-        chunks = [
-            state_groups_to_delete[i:i + 100]
-            for i in xrange(0, len(state_groups_to_delete), 100)
-        ]
-        for chunk in chunks:
+        logger.info("[purge] finding state groups which depend on redundant"
+                    " state groups")
+        remaining_state_groups = []
+        for i in xrange(0, len(state_rows), 100):
+            chunk = [sg for sg, in state_rows[i:i + 100]]
+            # look for state groups whose prev_state_group is one we are about
+            # to delete
             rows = self._simple_select_many_txn(
                 txn,
                 table="state_group_edges",
@@ -1773,21 +1987,28 @@ class EventsStore(SQLBaseStore):
                 retcols=["state_group"],
                 keyvalues={},
             )
-            new_state_edges.extend(row["state_group"] for row in rows)
+            remaining_state_groups.extend(
+                row["state_group"] for row in rows
+
+                # exclude state groups we are about to delete: no point in
+                # updating them
+                if row["state_group"] not in state_groups_to_delete
+            )
 
-        # Now we turn the state groups that reference to-be-deleted state groups
-        # to non delta versions.
-        for new_state_edge in new_state_edges:
+        # Now we turn the state groups that reference to-be-deleted state
+        # groups to non delta versions.
+        for sg in remaining_state_groups:
+            logger.info("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
-                txn, [new_state_edge], types=None
+                txn, [sg], types=None
             )
-            curr_state = curr_state[new_state_edge]
+            curr_state = curr_state[sg]
 
             self._simple_delete_txn(
                 txn,
                 table="state_groups_state",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -1795,7 +2016,7 @@ class EventsStore(SQLBaseStore):
                 txn,
                 table="state_group_edges",
                 keyvalues={
-                    "state_group": new_state_edge,
+                    "state_group": sg,
                 }
             )
 
@@ -1804,16 +2025,17 @@ class EventsStore(SQLBaseStore):
                 table="state_groups_state",
                 values=[
                     {
-                        "state_group": new_state_edge,
+                        "state_group": sg,
                         "room_id": room_id,
                         "type": key[0],
                         "state_key": key[1],
                         "event_id": state_id,
                     }
-                    for key, state_id in curr_state.items()
+                    for key, state_id in curr_state.iteritems()
                 ],
             )
 
+        logger.info("[purge] removing redundant state groups")
         txn.executemany(
             "DELETE FROM state_groups_state WHERE state_group = ?",
             state_rows
@@ -1822,22 +2044,18 @@ class EventsStore(SQLBaseStore):
             "DELETE FROM state_groups WHERE id = ?",
             state_rows
         )
-        # Delete all non-state
-        txn.executemany(
-            "DELETE FROM event_to_state_groups WHERE event_id = ?",
-            [(event_id,) for event_id, _ in event_rows]
-        )
 
+        logger.info("[purge] removing events from event_to_state_groups")
         txn.execute(
-            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
-            (topological_ordering, room_id,)
+            "DELETE FROM event_to_state_groups "
+            "WHERE event_id IN (SELECT event_id from events_to_purge)"
         )
+        for event_id, _ in event_rows:
+            txn.call_after(self._get_state_group_for_event.invalidate, (
+                event_id,
+            ))
 
         # Delete all remote non-state events
-        to_delete = [
-            (event_id,) for event_id, state_key in event_rows
-            if state_key is None and not self.hs.is_mine_id(event_id)
-        ]
         for table in (
             "events",
             "event_json",
@@ -1847,29 +2065,102 @@ class EventsStore(SQLBaseStore):
             "event_edge_hashes",
             "event_edges",
             "event_forward_extremities",
-            "event_push_actions",
             "event_reference_hashes",
             "event_search",
             "event_signatures",
             "rejections",
         ):
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                to_delete
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+            )
+
+        # event_push_actions lacks an index on event_id, and has one on
+        # (room_id, event_id) instead.
+        for table in (
+            "event_push_actions",
+        ):
+            logger.info("[purge] removing events from %s", table)
+
+            txn.execute(
+                "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+                "    SELECT event_id FROM events_to_purge WHERE should_delete"
+                ")" % (table,),
+                (room_id, )
             )
 
-        txn.executemany(
-            "DELETE FROM events WHERE event_id = ?",
-            to_delete
-        )
         # Mark all state and own events as outliers
-        txn.executemany(
+        logger.info("[purge] marking remaining events as outliers")
+        txn.execute(
             "UPDATE events SET outlier = ?"
-            " WHERE event_id = ?",
-            [
-                (True, event_id,) for event_id, state_key in event_rows
-                if state_key is not None or self.hs.is_mine_id(event_id)
-            ]
+            " WHERE event_id IN ("
+            "    SELECT event_id FROM events_to_purge "
+            "    WHERE NOT should_delete"
+            ")",
+            (True,),
+        )
+
+        # synapse tries to take out an exclusive lock on room_depth whenever it
+        # persists events (because upsert), and once we run this update, we
+        # will block that for the rest of our transaction.
+        #
+        # So, let's stick it at the end so that we don't block event
+        # persistence.
+        logger.info("[purge] updating room_depth")
+        txn.execute(
+            "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+            (topological_ordering, room_id,)
+        )
+
+        # finally, drop the temp table. this will commit the txn in sqlite,
+        # so make sure to keep this actually last.
+        txn.execute(
+            "DROP TABLE events_to_purge"
+        )
+
+        logger.info("[purge] done")
+
+    @defer.inlineCallbacks
+    def is_event_after(self, event_id1, event_id2):
+        """Returns True if event_id1 is after event_id2 in the stream
+        """
+        to_1, so_1 = yield self._get_event_ordering(event_id1)
+        to_2, so_2 = yield self._get_event_ordering(event_id2)
+        defer.returnValue((to_1, so_1) > (to_2, so_2))
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def _get_event_ordering(self, event_id):
+        res = yield self._simple_select_one(
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        if not res:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"])))
+
+    def get_max_current_state_delta_stream_id(self):
+        return self._stream_id_gen.get_current_token()
+
+    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+        def get_all_updated_current_state_deltas_txn(txn):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit))
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_current_state_deltas",
+            get_all_updated_current_state_deltas_txn,
         )
 
 
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
new file mode 100644
index 0000000000..ba834854e1
--- /dev/null
+++ b/synapse/storage/events_worker.py
@@ -0,0 +1,416 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+from twisted.internet import defer, reactor
+
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import (
+    PreserveLoggingContext, make_deferred_yieldable, run_in_background,
+)
+from synapse.util.metrics import Measure
+from synapse.api.errors import SynapseError
+
+from collections import namedtuple
+
+import logging
+import simplejson as json
+
+# these are only included to make the type annotations work
+from synapse.events import EventBase    # noqa: F401
+from synapse.events.snapshot import EventContext   # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventsWorkerStore(SQLBaseStore):
+    def get_received_ts(self, event_id):
+        """Get received_ts (when it was persisted) for the event.
+
+        Raises an exception for unknown events.
+
+        Args:
+            event_id (str)
+
+        Returns:
+            Deferred[int|None]: Timestamp in milliseconds, or None for events
+            that were persisted before received_ts was implemented.
+        """
+        return self._simple_select_one_onecol(
+            table="events",
+            keyvalues={
+                "event_id": event_id,
+            },
+            retcol="received_ts",
+            desc="get_received_ts",
+        )
+
+    @defer.inlineCallbacks
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+            allow_none (bool): If True, return None if no event found, if
+                False throw an exception.
+
+        Returns:
+            Deferred : A FrozenEvent.
+        """
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        if not events and not allow_none:
+            raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+        defer.returnValue(events[0] if events else None)
+
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
+    @defer.inlineCallbacks
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
+        event_entry_map = self._get_events_from_cache(
+            event_ids,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+        if missing_events_ids:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                allow_rejected=allow_rejected,
+            )
+
+            event_entry_map.update(missing_events)
+
+        events = []
+        for event_id in event_id_list:
+            entry = event_entry_map.get(event_id, None)
+            if not entry:
+                continue
+
+            if allow_rejected or not entry.event.rejected_reason:
+                if check_redacted and entry.redacted_event:
+                    event = entry.redacted_event
+                else:
+                    event = entry.event
+
+                events.append(event)
+
+                if get_prev_content:
+                    if "replaces_state" in event.unsigned:
+                        prev = yield self.get_event(
+                            event.unsigned["replaces_state"],
+                            get_prev_content=False,
+                            allow_none=True,
+                        )
+                        if prev:
+                            event.unsigned = dict(event.unsigned)
+                            event.unsigned["prev_content"] = prev.content
+                            event.unsigned["prev_sender"] = prev.sender
+
+        defer.returnValue(events)
+
+    def _invalidate_get_event_cache(self, event_id):
+            self._get_event_cache.invalidate((event_id,))
+
+    def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+        """Fetch events from the caches
+
+        Args:
+            events (list(str)): list of event_ids to fetch
+            allow_rejected (bool): Whether to teturn events that were rejected
+            update_metrics (bool): Whether to update the cache hit ratio metrics
+
+        Returns:
+            dict of event_id -> _EventCacheEntry for each event_id in cache. If
+            allow_rejected is `False` then there will still be an entry but it
+            will be `None`
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = self._get_event_cache.get(
+                (event_id,), None,
+                update_metrics=update_metrics,
+            )
+            if not ret:
+                continue
+
+            if allow_rejected or not ret.event.rejected_reason:
+                event_map[event_id] = ret
+            else:
+                event_map[event_id] = None
+
+        return event_map
+
+    def _do_fetch(self, conn):
+        """Takes a database connection and waits for requests for events from
+        the _event_fetch_list queue.
+        """
+        event_list = []
+        i = 0
+        while True:
+            try:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                            self._event_fetch_ongoing -= 1
+                            return
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids
+                )
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                # We only want to resolve deferreds from the main thread
+                def fire(lst, res):
+                    for ids, d in lst:
+                        if not d.called:
+                            try:
+                                with PreserveLoggingContext():
+                                    d.callback([
+                                        res[i]
+                                        for i in ids
+                                        if i in res
+                                    ])
+                            except Exception:
+                                logger.exception("Failed to callback")
+                with PreserveLoggingContext():
+                    reactor.callFromThread(fire, event_list, row_dict)
+            except Exception as e:
+                logger.exception("do_fetch")
+
+                # We only want to resolve deferreds from the main thread
+                def fire(evs):
+                    for _, d in evs:
+                        if not d.called:
+                            with PreserveLoggingContext():
+                                d.errback(e)
+
+                if event_list:
+                    with PreserveLoggingContext():
+                        reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
+        """Fetches events from the database using the _event_fetch_list. This
+        allows batch and bulk fetching of events - it allows us to fetch events
+        without having to create a new transaction for each request for events.
+        """
+        if not events:
+            defer.returnValue({})
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify()
+
+            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+                self._event_fetch_ongoing += 1
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            with PreserveLoggingContext():
+                self.runWithConnection(
+                    self._do_fetch
+                )
+
+        logger.debug("Loading %d events", len(events))
+        with PreserveLoggingContext():
+            rows = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = yield make_deferred_yieldable(defer.gatherResults(
+            [
+                run_in_background(
+                    self._get_event_from_row,
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ],
+            consumeErrors=True
+        ))
+
+        defer.returnValue({
+            e.event.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
+        rows = []
+        N = 200
+        for i in range(1 + len(events) / N):
+            evs = events[i * N:(i + 1) * N]
+            if not evs:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"] * len(evs)),)
+
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            rejected_reason=None):
+        with Measure(self._clock, "_get_event_from_row"):
+            d = json.loads(js)
+            internal_metadata = json.loads(internal_metadata)
+
+            if rejected_reason:
+                rejected_reason = yield self._simple_select_one_onecol(
+                    table="rejections",
+                    keyvalues={"event_id": rejected_reason},
+                    retcol="reason",
+                    desc="_get_event_from_row_rejected_reason",
+                )
+
+            original_ev = FrozenEvent(
+                d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            redacted_event = None
+            if redacted:
+                redacted_event = prune_event(original_ev)
+
+                redaction_id = yield self._simple_select_one_onecol(
+                    table="redactions",
+                    keyvalues={"redacts": redacted_event.event_id},
+                    retcol="event_id",
+                    desc="_get_event_from_row_redactions",
+                )
+
+                redacted_event.unsigned["redacted_by"] = redaction_id
+                # Get the redaction event.
+
+                because = yield self.get_event(
+                    redaction_id,
+                    check_redacted=False,
+                    allow_none=True,
+                )
+
+                if because:
+                    # It's fine to do add the event directly, since get_pdu_json
+                    # will serialise this field correctly
+                    redacted_event.unsigned["redacted_because"] = because
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev,
+                redacted_event=redacted_event,
+            )
+
+            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
+
+        defer.returnValue(cache_entry)
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index a2ccc66ea7..78b1e30945 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -19,6 +19,7 @@ from ._base import SQLBaseStore
 from synapse.api.errors import SynapseError, Codes
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
+from canonicaljson import encode_canonical_json
 import simplejson as json
 
 
@@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore):
         defer.returnValue(json.loads(str(def_json).decode("utf-8")))
 
     def add_user_filter(self, user_localpart, user_filter):
-        def_json = json.dumps(user_filter).encode("utf-8")
+        def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then
         # INSERT a new one
         def _do_txn(txn):
             sql = (
+                "SELECT filter_id FROM user_filters "
+                "WHERE user_id = ? AND filter_json = ?"
+            )
+            txn.execute(sql, (user_localpart, def_json))
+            filter_id_response = txn.fetchone()
+            if filter_id_response is not None:
+                return filter_id_response[0]
+
+            sql = (
                 "SELECT MAX(filter_id) FROM user_filters "
                 "WHERE user_id = ?"
             )
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
new file mode 100644
index 0000000000..da05ccb027
--- /dev/null
+++ b/synapse/storage/group_server.py
@@ -0,0 +1,1253 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+
+from ._base import SQLBaseStore
+
+import simplejson as json
+
+
+# The category ID for the "default" category. We don't store as null in the
+# database to avoid the fun of null != null
+_DEFAULT_CATEGORY_ID = ""
+_DEFAULT_ROLE_ID = ""
+
+
+class GroupServerStore(SQLBaseStore):
+    def set_group_join_policy(self, group_id, join_policy):
+        """Set the join policy of a group.
+
+        join_policy can be one of:
+         * "invite"
+         * "open"
+        """
+        return self._simple_update_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            updatevalues={
+                "join_policy": join_policy,
+            },
+            desc="set_group_join_policy",
+        )
+
+    def get_group(self, group_id):
+        return self._simple_select_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=(
+                "name", "short_description", "long_description",
+                "avatar_url", "is_public", "join_policy",
+            ),
+            allow_none=True,
+            desc="get_group",
+        )
+
+    def get_users_in_group(self, group_id, include_private=False):
+        # TODO: Pagination
+
+        keyvalues = {
+            "group_id": group_id,
+        }
+        if not include_private:
+            keyvalues["is_public"] = True
+
+        return self._simple_select_list(
+            table="group_users",
+            keyvalues=keyvalues,
+            retcols=("user_id", "is_public", "is_admin",),
+            desc="get_users_in_group",
+        )
+
+    def get_invited_users_in_group(self, group_id):
+        # TODO: Pagination
+
+        return self._simple_select_onecol(
+            table="group_invites",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcol="user_id",
+            desc="get_invited_users_in_group",
+        )
+
+    def get_rooms_in_group(self, group_id, include_private=False):
+        # TODO: Pagination
+
+        keyvalues = {
+            "group_id": group_id,
+        }
+        if not include_private:
+            keyvalues["is_public"] = True
+
+        return self._simple_select_list(
+            table="group_rooms",
+            keyvalues=keyvalues,
+            retcols=("room_id", "is_public",),
+            desc="get_rooms_in_group",
+        )
+
+    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+        """Get the rooms and categories that should be included in a summary request
+
+        Returns ([rooms], [categories])
+        """
+        def _get_rooms_for_summary_txn(txn):
+            keyvalues = {
+                "group_id": group_id,
+            }
+            if not include_private:
+                keyvalues["is_public"] = True
+
+            sql = """
+                SELECT room_id, is_public, category_id, room_order
+                FROM group_summary_rooms
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            rooms = [
+                {
+                    "room_id": row[0],
+                    "is_public": row[1],
+                    "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None,
+                    "order": row[3],
+                }
+                for row in txn
+            ]
+
+            sql = """
+                SELECT category_id, is_public, profile, cat_order
+                FROM group_summary_room_categories
+                INNER JOIN group_room_categories USING (group_id, category_id)
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            categories = {
+                row[0]: {
+                    "is_public": row[1],
+                    "profile": json.loads(row[2]),
+                    "order": row[3],
+                }
+                for row in txn
+            }
+
+            return rooms, categories
+        return self.runInteraction(
+            "get_rooms_for_summary", _get_rooms_for_summary_txn
+        )
+
+    def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
+        return self.runInteraction(
+            "add_room_to_summary", self._add_room_to_summary_txn,
+            group_id, room_id, category_id, order, is_public,
+        )
+
+    def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order,
+                                 is_public):
+        """Add (or update) room's entry in summary.
+
+        Args:
+            group_id (str)
+            room_id (str)
+            category_id (str): If not None then adds the category to the end of
+                the summary if its not already there. [Optional]
+            order (int): If not None inserts the room at that position, e.g.
+                an order of 1 will put the room first. Otherwise, the room gets
+                added to the end.
+        """
+        room_in_group = self._simple_select_one_onecol_txn(
+            txn,
+            table="group_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+            },
+            retcol="room_id",
+            allow_none=True,
+        )
+        if not room_in_group:
+            raise SynapseError(400, "room not in group")
+
+        if category_id is None:
+            category_id = _DEFAULT_CATEGORY_ID
+        else:
+            cat_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_room_categories",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not cat_exists:
+                raise SynapseError(400, "Category doesn't exist")
+
+            # TODO: Check category is part of summary already
+            cat_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_summary_room_categories",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not cat_exists:
+                # If not, add it with an order larger than all others
+                txn.execute("""
+                    INSERT INTO group_summary_room_categories
+                    (group_id, category_id, cat_order)
+                    SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1
+                    FROM group_summary_room_categories
+                    WHERE group_id = ? AND category_id = ?
+                """, (group_id, category_id, group_id, category_id))
+
+        existing = self._simple_select_one_txn(
+            txn,
+            table="group_summary_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+                "category_id": category_id,
+            },
+            retcols=("room_order", "is_public",),
+            allow_none=True,
+        )
+
+        if order is not None:
+            # Shuffle other room orders that come after the given order
+            sql = """
+                UPDATE group_summary_rooms SET room_order = room_order + 1
+                WHERE group_id = ? AND category_id = ? AND room_order >= ?
+            """
+            txn.execute(sql, (group_id, category_id, order,))
+        elif not existing:
+            sql = """
+                SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms
+                WHERE group_id = ? AND category_id = ?
+            """
+            txn.execute(sql, (group_id, category_id,))
+            order, = txn.fetchone()
+
+        if existing:
+            to_update = {}
+            if order is not None:
+                to_update["room_order"] = order
+            if is_public is not None:
+                to_update["is_public"] = is_public
+            self._simple_update_txn(
+                txn,
+                table="group_summary_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                    "room_id": room_id,
+                },
+                values=to_update,
+            )
+        else:
+            if is_public is None:
+                is_public = True
+
+            self._simple_insert_txn(
+                txn,
+                table="group_summary_rooms",
+                values={
+                    "group_id": group_id,
+                    "category_id": category_id,
+                    "room_id": room_id,
+                    "room_order": order,
+                    "is_public": is_public,
+                },
+            )
+
+    def remove_room_from_summary(self, group_id, room_id, category_id):
+        if category_id is None:
+            category_id = _DEFAULT_CATEGORY_ID
+
+        return self._simple_delete(
+            table="group_summary_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+                "room_id": room_id,
+            },
+            desc="remove_room_from_summary",
+        )
+
+    @defer.inlineCallbacks
+    def get_group_categories(self, group_id):
+        rows = yield self._simple_select_list(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=("category_id", "is_public", "profile"),
+            desc="get_group_categories",
+        )
+
+        defer.returnValue({
+            row["category_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def get_group_category(self, group_id, category_id):
+        category = yield self._simple_select_one(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            retcols=("is_public", "profile"),
+            desc="get_group_category",
+        )
+
+        category["profile"] = json.loads(category["profile"])
+
+        defer.returnValue(category)
+
+    def upsert_group_category(self, group_id, category_id, profile, is_public):
+        """Add/update room category for group
+        """
+        insertion_values = {}
+        update_values = {"category_id": category_id}  # This cannot be empty
+
+        if profile is None:
+            insertion_values["profile"] = "{}"
+        else:
+            update_values["profile"] = json.dumps(profile)
+
+        if is_public is None:
+            insertion_values["is_public"] = True
+        else:
+            update_values["is_public"] = is_public
+
+        return self._simple_upsert(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            values=update_values,
+            insertion_values=insertion_values,
+            desc="upsert_group_category",
+        )
+
+    def remove_group_category(self, group_id, category_id):
+        return self._simple_delete(
+            table="group_room_categories",
+            keyvalues={
+                "group_id": group_id,
+                "category_id": category_id,
+            },
+            desc="remove_group_category",
+        )
+
+    @defer.inlineCallbacks
+    def get_group_roles(self, group_id):
+        rows = yield self._simple_select_list(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+            },
+            retcols=("role_id", "is_public", "profile"),
+            desc="get_group_roles",
+        )
+
+        defer.returnValue({
+            row["role_id"]: {
+                "is_public": row["is_public"],
+                "profile": json.loads(row["profile"]),
+            }
+            for row in rows
+        })
+
+    @defer.inlineCallbacks
+    def get_group_role(self, group_id, role_id):
+        role = yield self._simple_select_one(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            retcols=("is_public", "profile"),
+            desc="get_group_role",
+        )
+
+        role["profile"] = json.loads(role["profile"])
+
+        defer.returnValue(role)
+
+    def upsert_group_role(self, group_id, role_id, profile, is_public):
+        """Add/remove user role
+        """
+        insertion_values = {}
+        update_values = {"role_id": role_id}  # This cannot be empty
+
+        if profile is None:
+            insertion_values["profile"] = "{}"
+        else:
+            update_values["profile"] = json.dumps(profile)
+
+        if is_public is None:
+            insertion_values["is_public"] = True
+        else:
+            update_values["is_public"] = is_public
+
+        return self._simple_upsert(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            values=update_values,
+            insertion_values=insertion_values,
+            desc="upsert_group_role",
+        )
+
+    def remove_group_role(self, group_id, role_id):
+        return self._simple_delete(
+            table="group_roles",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+            },
+            desc="remove_group_role",
+        )
+
+    def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
+        return self.runInteraction(
+            "add_user_to_summary", self._add_user_to_summary_txn,
+            group_id, user_id, role_id, order, is_public,
+        )
+
+    def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order,
+                                 is_public):
+        """Add (or update) user's entry in summary.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            role_id (str): If not None then adds the role to the end of
+                the summary if its not already there. [Optional]
+            order (int): If not None inserts the user at that position, e.g.
+                an order of 1 will put the user first. Otherwise, the user gets
+                added to the end.
+        """
+        user_in_group = self._simple_select_one_onecol_txn(
+            txn,
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            allow_none=True,
+        )
+        if not user_in_group:
+            raise SynapseError(400, "user not in group")
+
+        if role_id is None:
+            role_id = _DEFAULT_ROLE_ID
+        else:
+            role_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_roles",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not role_exists:
+                raise SynapseError(400, "Role doesn't exist")
+
+            # TODO: Check role is part of the summary already
+            role_exists = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_summary_roles",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                },
+                retcol="group_id",
+                allow_none=True,
+            )
+            if not role_exists:
+                # If not, add it with an order larger than all others
+                txn.execute("""
+                    INSERT INTO group_summary_roles
+                    (group_id, role_id, role_order)
+                    SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1
+                    FROM group_summary_roles
+                    WHERE group_id = ? AND role_id = ?
+                """, (group_id, role_id, group_id, role_id))
+
+        existing = self._simple_select_one_txn(
+            txn,
+            table="group_summary_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+                "role_id": role_id,
+            },
+            retcols=("user_order", "is_public",),
+            allow_none=True,
+        )
+
+        if order is not None:
+            # Shuffle other users orders that come after the given order
+            sql = """
+                UPDATE group_summary_users SET user_order = user_order + 1
+                WHERE group_id = ? AND role_id = ? AND user_order >= ?
+            """
+            txn.execute(sql, (group_id, role_id, order,))
+        elif not existing:
+            sql = """
+                SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users
+                WHERE group_id = ? AND role_id = ?
+            """
+            txn.execute(sql, (group_id, role_id,))
+            order, = txn.fetchone()
+
+        if existing:
+            to_update = {}
+            if order is not None:
+                to_update["user_order"] = order
+            if is_public is not None:
+                to_update["is_public"] = is_public
+            self._simple_update_txn(
+                txn,
+                table="group_summary_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                    "user_id": user_id,
+                },
+                values=to_update,
+            )
+        else:
+            if is_public is None:
+                is_public = True
+
+            self._simple_insert_txn(
+                txn,
+                table="group_summary_users",
+                values={
+                    "group_id": group_id,
+                    "role_id": role_id,
+                    "user_id": user_id,
+                    "user_order": order,
+                    "is_public": is_public,
+                },
+            )
+
+    def remove_user_from_summary(self, group_id, user_id, role_id):
+        if role_id is None:
+            role_id = _DEFAULT_ROLE_ID
+
+        return self._simple_delete(
+            table="group_summary_users",
+            keyvalues={
+                "group_id": group_id,
+                "role_id": role_id,
+                "user_id": user_id,
+            },
+            desc="remove_user_from_summary",
+        )
+
+    def get_users_for_summary_by_role(self, group_id, include_private=False):
+        """Get the users and roles that should be included in a summary request
+
+        Returns ([users], [roles])
+        """
+        def _get_users_for_summary_txn(txn):
+            keyvalues = {
+                "group_id": group_id,
+            }
+            if not include_private:
+                keyvalues["is_public"] = True
+
+            sql = """
+                SELECT user_id, is_public, role_id, user_order
+                FROM group_summary_users
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            users = [
+                {
+                    "user_id": row[0],
+                    "is_public": row[1],
+                    "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+                    "order": row[3],
+                }
+                for row in txn
+            ]
+
+            sql = """
+                SELECT role_id, is_public, profile, role_order
+                FROM group_summary_roles
+                INNER JOIN group_roles USING (group_id, role_id)
+                WHERE group_id = ?
+            """
+
+            if not include_private:
+                sql += " AND is_public = ?"
+                txn.execute(sql, (group_id, True))
+            else:
+                txn.execute(sql, (group_id,))
+
+            roles = {
+                row[0]: {
+                    "is_public": row[1],
+                    "profile": json.loads(row[2]),
+                    "order": row[3],
+                }
+                for row in txn
+            }
+
+            return users, roles
+        return self.runInteraction(
+            "get_users_for_summary_by_role", _get_users_for_summary_txn
+        )
+
+    def is_user_in_group(self, user_id, group_id):
+        return self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            allow_none=True,
+            desc="is_user_in_group",
+        ).addCallback(lambda r: bool(r))
+
+    def is_user_admin_in_group(self, group_id, user_id):
+        return self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="is_admin",
+            allow_none=True,
+            desc="is_user_admin_in_group",
+        )
+
+    def add_group_invite(self, group_id, user_id):
+        """Record that the group server has invited a user
+        """
+        return self._simple_insert(
+            table="group_invites",
+            values={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            desc="add_group_invite",
+        )
+
+    def is_user_invited_to_local_group(self, group_id, user_id):
+        """Has the group server invited a user?
+        """
+        return self._simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcol="user_id",
+            desc="is_user_invited_to_local_group",
+            allow_none=True,
+        )
+
+    def get_users_membership_info_in_group(self, group_id, user_id):
+        """Get a dict describing the membership of a user in a group.
+
+        Example if joined:
+
+            {
+                "membership": "join",
+                "is_public": True,
+                "is_privileged": False,
+            }
+
+        Returns an empty dict if the user is not join/invite/etc
+        """
+        def _get_users_membership_in_group_txn(txn):
+            row = self._simple_select_one_txn(
+                txn,
+                table="group_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+                retcols=("is_admin", "is_public"),
+                allow_none=True,
+            )
+
+            if row:
+                return {
+                    "membership": "join",
+                    "is_public": row["is_public"],
+                    "is_privileged": row["is_admin"],
+                }
+
+            row = self._simple_select_one_onecol_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+                retcol="user_id",
+                allow_none=True,
+            )
+
+            if row:
+                return {
+                    "membership": "invite",
+                }
+
+            return {}
+
+        return self.runInteraction(
+            "get_users_membership_info_in_group", _get_users_membership_in_group_txn,
+        )
+
+    def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True,
+                          local_attestation=None, remote_attestation=None):
+        """Add a user to the group server.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            is_admin (bool)
+            is_public (bool)
+            local_attestation (dict): The attestation the GS created to give
+                to the remote server. Optional if the user and group are on the
+                same server
+            remote_attestation (dict): The attestation given to GS by remote
+                server. Optional if the user and group are on the same server
+        """
+        def _add_user_to_group_txn(txn):
+            self._simple_insert_txn(
+                txn,
+                table="group_users",
+                values={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "is_admin": is_admin,
+                    "is_public": is_public,
+                },
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+
+            if local_attestation:
+                self._simple_insert_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    values={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                        "valid_until_ms": local_attestation["valid_until_ms"],
+                    },
+                )
+            if remote_attestation:
+                self._simple_insert_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    values={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                        "valid_until_ms": remote_attestation["valid_until_ms"],
+                        "attestation_json": json.dumps(remote_attestation),
+                    },
+                )
+
+        return self.runInteraction(
+            "add_user_to_group", _add_user_to_group_txn
+        )
+
+    def remove_user_from_group(self, group_id, user_id):
+        def _remove_user_from_group_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="group_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_invites",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_attestations_renewals",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_attestations_remote",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_delete_txn(
+                txn,
+                table="group_summary_users",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+        return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn)
+
+    def add_room_to_group(self, group_id, room_id, is_public):
+        return self._simple_insert(
+            table="group_rooms",
+            values={
+                "group_id": group_id,
+                "room_id": room_id,
+                "is_public": is_public,
+            },
+            desc="add_room_to_group",
+        )
+
+    def update_room_in_group_visibility(self, group_id, room_id, is_public):
+        return self._simple_update(
+            table="group_rooms",
+            keyvalues={
+                "group_id": group_id,
+                "room_id": room_id,
+            },
+            updatevalues={
+                "is_public": is_public,
+            },
+            desc="update_room_in_group_visibility",
+        )
+
+    def remove_room_from_group(self, group_id, room_id):
+        def _remove_room_from_group_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="group_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "room_id": room_id,
+                },
+            )
+
+            self._simple_delete_txn(
+                txn,
+                table="group_summary_rooms",
+                keyvalues={
+                    "group_id": group_id,
+                    "room_id": room_id,
+                },
+            )
+        return self.runInteraction(
+            "remove_room_from_group", _remove_room_from_group_txn,
+        )
+
+    def get_publicised_groups_for_user(self, user_id):
+        """Get all groups a user is publicising
+        """
+        return self._simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={
+                "user_id": user_id,
+                "membership": "join",
+                "is_publicised": True,
+            },
+            retcol="group_id",
+            desc="get_publicised_groups_for_user",
+        )
+
+    def update_group_publicity(self, group_id, user_id, publicise):
+        """Update whether the user is publicising their membership of the group
+        """
+        return self._simple_update_one(
+            table="local_group_membership",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "is_publicised": publicise,
+            },
+            desc="update_group_publicity"
+        )
+
+    @defer.inlineCallbacks
+    def register_user_group_membership(self, group_id, user_id, membership,
+                                       is_admin=False, content={},
+                                       local_attestation=None,
+                                       remote_attestation=None,
+                                       is_publicised=False,
+                                       ):
+        """Registers that a local user is a member of a (local or remote) group.
+
+        Args:
+            group_id (str)
+            user_id (str)
+            membership (str)
+            is_admin (bool)
+            content (dict): Content of the membership, e.g. includes the inviter
+                if the user has been invited.
+            local_attestation (dict): If remote group then store the fact that we
+                have given out an attestation, else None.
+            remote_attestation (dict): If remote group then store the remote
+                attestation from the group, else None.
+        """
+        def _register_user_group_membership_txn(txn, next_id):
+            # TODO: Upsert?
+            self._simple_delete_txn(
+                txn,
+                table="local_group_membership",
+                keyvalues={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                },
+            )
+            self._simple_insert_txn(
+                txn,
+                table="local_group_membership",
+                values={
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "is_admin": is_admin,
+                    "membership": membership,
+                    "is_publicised": is_publicised,
+                    "content": json.dumps(content),
+                },
+            )
+
+            self._simple_insert_txn(
+                txn,
+                table="local_group_updates",
+                values={
+                    "stream_id": next_id,
+                    "group_id": group_id,
+                    "user_id": user_id,
+                    "type": "membership",
+                    "content": json.dumps({"membership": membership, "content": content}),
+                }
+            )
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+
+            # TODO: Insert profile to ensure it comes down stream if its a join.
+
+            if membership == "join":
+                if local_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_renewals",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": local_attestation["valid_until_ms"],
+                        }
+                    )
+                if remote_attestation:
+                    self._simple_insert_txn(
+                        txn,
+                        table="group_attestations_remote",
+                        values={
+                            "group_id": group_id,
+                            "user_id": user_id,
+                            "valid_until_ms": remote_attestation["valid_until_ms"],
+                            "attestation_json": json.dumps(remote_attestation),
+                        }
+                    )
+            else:
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_renewals",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+                self._simple_delete_txn(
+                    txn,
+                    table="group_attestations_remote",
+                    keyvalues={
+                        "group_id": group_id,
+                        "user_id": user_id,
+                    },
+                )
+
+            return next_id
+
+        with self._group_updates_id_gen.get_next() as next_id:
+            res = yield self.runInteraction(
+                "register_user_group_membership",
+                _register_user_group_membership_txn, next_id,
+            )
+        defer.returnValue(res)
+
+    @defer.inlineCallbacks
+    def create_group(self, group_id, user_id, name, avatar_url, short_description,
+                     long_description,):
+        yield self._simple_insert(
+            table="groups",
+            values={
+                "group_id": group_id,
+                "name": name,
+                "avatar_url": avatar_url,
+                "short_description": short_description,
+                "long_description": long_description,
+                "is_public": True,
+            },
+            desc="create_group",
+        )
+
+    @defer.inlineCallbacks
+    def update_group_profile(self, group_id, profile,):
+        yield self._simple_update_one(
+            table="groups",
+            keyvalues={
+                "group_id": group_id,
+            },
+            updatevalues=profile,
+            desc="update_group_profile",
+        )
+
+    def get_attestations_need_renewals(self, valid_until_ms):
+        """Get all attestations that need to be renewed until givent time
+        """
+        def _get_attestations_need_renewals_txn(txn):
+            sql = """
+                SELECT group_id, user_id FROM group_attestations_renewals
+                WHERE valid_until_ms <= ?
+            """
+            txn.execute(sql, (valid_until_ms,))
+            return self.cursor_to_dict(txn)
+        return self.runInteraction(
+            "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+        )
+
+    def update_attestation_renewal(self, group_id, user_id, attestation):
+        """Update an attestation that we have renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_renewals",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+            },
+            desc="update_attestation_renewal",
+        )
+
+    def update_remote_attestion(self, group_id, user_id, attestation):
+        """Update an attestation that a remote has renewed
+        """
+        return self._simple_update_one(
+            table="group_attestations_remote",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            updatevalues={
+                "valid_until_ms": attestation["valid_until_ms"],
+                "attestation_json": json.dumps(attestation)
+            },
+            desc="update_remote_attestion",
+        )
+
+    def remove_attestation_renewal(self, group_id, user_id):
+        """Remove an attestation that we thought we should renew, but actually
+        shouldn't. Ideally this would never get called as we would never
+        incorrectly try and do attestations for local users on local groups.
+
+        Args:
+            group_id (str)
+            user_id (str)
+        """
+        return self._simple_delete(
+            table="group_attestations_renewals",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            desc="remove_attestation_renewal",
+        )
+
+    @defer.inlineCallbacks
+    def get_remote_attestation(self, group_id, user_id):
+        """Get the attestation that proves the remote agrees that the user is
+        in the group.
+        """
+        row = yield self._simple_select_one(
+            table="group_attestations_remote",
+            keyvalues={
+                "group_id": group_id,
+                "user_id": user_id,
+            },
+            retcols=("valid_until_ms", "attestation_json"),
+            desc="get_remote_attestation",
+            allow_none=True,
+        )
+
+        now = int(self._clock.time_msec())
+        if row and now < row["valid_until_ms"]:
+            defer.returnValue(json.loads(row["attestation_json"]))
+
+        defer.returnValue(None)
+
+    def get_joined_groups(self, user_id):
+        return self._simple_select_onecol(
+            table="local_group_membership",
+            keyvalues={
+                "user_id": user_id,
+                "membership": "join",
+            },
+            retcol="group_id",
+            desc="get_joined_groups",
+        )
+
+    def get_all_groups_for_user(self, user_id, now_token):
+        def _get_all_groups_for_user_txn(txn):
+            sql = """
+                SELECT group_id, type, membership, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND membership != 'leave'
+                    AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, now_token,))
+            return [
+                {
+                    "group_id": row[0],
+                    "type": row[1],
+                    "membership": row[2],
+                    "content": json.loads(row[3]),
+                }
+                for row in txn
+            ]
+        return self.runInteraction(
+            "get_all_groups_for_user", _get_all_groups_for_user_txn,
+        )
+
+    def get_groups_changes_for_user(self, user_id, from_token, to_token):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_entity_changed(
+            user_id, from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_groups_changes_for_user_txn(txn):
+            sql = """
+                SELECT group_id, membership, type, u.content
+                FROM local_group_updates AS u
+                INNER JOIN local_group_membership USING (group_id, user_id)
+                WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+            """
+            txn.execute(sql, (user_id, from_token, to_token,))
+            return [{
+                "group_id": group_id,
+                "membership": membership,
+                "type": gtype,
+                "content": json.loads(content_json),
+            } for group_id, membership, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_groups_changes_for_user", _get_groups_changes_for_user_txn,
+        )
+
+    def get_all_groups_changes(self, from_token, to_token, limit):
+        from_token = int(from_token)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+            from_token,
+        )
+        if not has_changed:
+            return []
+
+        def _get_all_groups_changes_txn(txn):
+            sql = """
+                SELECT stream_id, group_id, user_id, type, content
+                FROM local_group_updates
+                WHERE ? < stream_id AND stream_id <= ?
+                LIMIT ?
+            """
+            txn.execute(sql, (from_token, to_token, limit,))
+            return [(
+                stream_id,
+                group_id,
+                user_id,
+                gtype,
+                json.loads(content_json),
+            ) for stream_id, group_id, user_id, gtype, content_json in txn]
+        return self.runInteraction(
+            "get_all_groups_changes", _get_all_groups_changes_txn,
+        )
+
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 86b37b9ddd..87aeaf71d6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore):
         key_ids
         Args:
             server_name (str): The name of the server.
-            key_ids (list of str): List of key_ids to try and look up.
+            key_ids (iterable[str]): key_ids to try and look up.
         Returns:
-            (list of VerifyKey): The verification keys.
+            Deferred: resolves to dict[str, VerifyKey]: map from
+               key_id to verification key.
         """
         keys = {}
         for key_id in key_ids:
@@ -112,30 +113,37 @@ class KeyStore(SQLBaseStore):
                 keys[key_id] = key
         defer.returnValue(keys)
 
-    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
         Args:
             server_name (str): The name of the server.
-            key_id (str): The version of the key for the server.
             from_server (str): Where the verification key was looked up
-            ts_now_ms (int): The time now in milliseconds
-            verification_key (VerifyKey): The NACL verify key.
+            time_now_ms (int): The time now in milliseconds
+            verify_key (nacl.signing.VerifyKey): The NACL verify key.
         """
-        yield self._simple_upsert(
-            table="server_signature_keys",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
-            },
-            values={
-                "from_server": from_server,
-                "ts_added_ms": time_now_ms,
-                "verify_key": buffer(verify_key.encode()),
-            },
-            desc="store_server_verify_key",
-        )
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+        def _txn(txn):
+            self._simple_upsert_txn(
+                txn,
+                table="server_signature_keys",
+                keyvalues={
+                    "server_name": server_name,
+                    "key_id": key_id,
+                },
+                values={
+                    "from_server": from_server,
+                    "ts_added_ms": time_now_ms,
+                    "verify_key": buffer(verify_key.encode()),
+                },
+            )
+            txn.call_after(
+                self._get_server_verify_key.invalidate,
+                (server_name, key_id)
+            )
+
+        return self.runInteraction("store_server_verify_key", _txn)
 
     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 4c0f82353d..e6cdbb0545 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -12,15 +12,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from synapse.storage.background_updates import BackgroundUpdateStore
 
-from ._base import SQLBaseStore
 
-
-class MediaRepositoryStore(SQLBaseStore):
+class MediaRepositoryStore(BackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def get_default_thumbnails(self, top_level_type, sub_type):
-        return []
+    def __init__(self, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+
+        self.register_background_index_update(
+            update_name='local_media_repository_url_idx',
+            index_name='local_media_repository_url_idx',
+            table='local_media_repository',
+            columns=['created_ts'],
+            where_clause='url_cache IS NOT NULL',
+        )
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
@@ -30,13 +37,16 @@ class MediaRepositoryStore(SQLBaseStore):
         return self._simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
-            ("media_type", "media_length", "upload_name", "created_ts"),
+            (
+                "media_type", "media_length", "upload_name", "created_ts",
+                "quarantined_by", "url_cache",
+            ),
             allow_none=True,
             desc="get_local_media",
         )
 
     def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
-                          media_length, user_id):
+                          media_length, user_id, url_cache=None):
         return self._simple_insert(
             "local_media_repository",
             {
@@ -46,6 +56,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 "upload_name": upload_name,
                 "media_length": media_length,
                 "user_id": user_id.to_string(),
+                "url_cache": url_cache,
             },
             desc="store_local_media",
         )
@@ -58,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore):
         def get_url_cache_txn(txn):
             # get the most recently cached result (relative to the given ts)
             sql = (
-                "SELECT response_code, etag, expires, og, media_id, download_ts"
+                "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                 " FROM local_media_repository_url_cache"
                 " WHERE url = ? AND download_ts <= ?"
                 " ORDER BY download_ts DESC LIMIT 1"
@@ -70,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 # ...or if we've requested a timestamp older than the oldest
                 # copy in the cache, return the oldest copy (if any)
                 sql = (
-                    "SELECT response_code, etag, expires, og, media_id, download_ts"
+                    "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
                     " FROM local_media_repository_url_cache"
                     " WHERE url = ? AND download_ts > ?"
                     " ORDER BY download_ts ASC LIMIT 1"
@@ -82,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore):
                 return None
 
             return dict(zip((
-                'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+                'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
             ), row))
 
         return self.runInteraction(
             "get_url_cache", get_url_cache_txn
         )
 
-    def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+    def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
                         download_ts):
         return self._simple_insert(
             "local_media_repository_url_cache",
@@ -97,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore):
                 "url": url,
                 "response_code": response_code,
                 "etag": etag,
-                "expires": expires,
+                "expires_ts": expires_ts,
                 "og": og,
                 "media_id": media_id,
                 "download_ts": download_ts,
@@ -138,7 +149,7 @@ class MediaRepositoryStore(SQLBaseStore):
             {"media_origin": origin, "media_id": media_id},
             (
                 "media_type", "media_length", "upload_name", "created_ts",
-                "filesystem_id",
+                "filesystem_id", "quarantined_by",
             ),
             allow_none=True,
             desc="get_cached_remote_media",
@@ -162,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore):
             desc="store_cached_remote_media",
         )
 
-    def update_cached_last_access_time(self, origin_id_tuples, time_ts):
+    def update_cached_last_access_time(self, local_media, remote_media, time_ms):
+        """Updates the last access time of the given media
+
+        Args:
+            local_media (iterable[str]): Set of media_ids
+            remote_media (iterable[(str, str)]): Set of (server_name, media_id)
+            time_ms: Current time in milliseconds
+        """
         def update_cache_txn(txn):
             sql = (
                 "UPDATE remote_media_cache SET last_access_ts = ?"
@@ -170,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore):
             )
 
             txn.executemany(sql, (
-                (time_ts, media_origin, media_id)
-                for media_origin, media_id in origin_id_tuples
+                (time_ms, media_origin, media_id)
+                for media_origin, media_id in remote_media
+            ))
+
+            sql = (
+                "UPDATE local_media_repository SET last_access_ts = ?"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, (
+                (time_ms, media_id)
+                for media_id in local_media
             ))
 
         return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@@ -234,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore):
                 },
             )
         return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+    def get_expired_url_cache(self, now_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository_url_cache"
+            " WHERE expires_ts < ?"
+            " ORDER BY expires_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_expired_url_cache_txn(txn):
+            txn.execute(sql, (now_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+    def delete_url_cache(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        sql = (
+            "DELETE FROM local_media_repository_url_cache"
+            " WHERE media_id = ?"
+        )
+
+        def _delete_url_cache_txn(txn):
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+    def get_url_cache_media_before(self, before_ts):
+        sql = (
+            "SELECT media_id FROM local_media_repository"
+            " WHERE created_ts < ? AND url_cache IS NOT NULL"
+            " ORDER BY created_ts ASC"
+            " LIMIT 500"
+        )
+
+        def _get_url_cache_media_before_txn(txn):
+            txn.execute(sql, (before_ts,))
+            return [row[0] for row in txn]
+
+        return self.runInteraction(
+            "get_url_cache_media_before", _get_url_cache_media_before_txn,
+        )
+
+    def delete_url_cache_media(self, media_ids):
+        if len(media_ids) == 0:
+            return
+
+        def _delete_url_cache_media_txn(txn):
+            sql = (
+                "DELETE FROM local_media_repository"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+            sql = (
+                "DELETE FROM local_media_repository_thumbnails"
+                " WHERE media_id = ?"
+            )
+
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+        return self.runInteraction(
+            "delete_url_cache_media", _delete_url_cache_media_txn,
+        )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b357f22be7..04411a665f 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 40
+SCHEMA_VERSION = 48
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -44,6 +45,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,9 +72,13 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
-    except:
+    except Exception:
         db_conn.rollback()
         raise
 
@@ -283,6 +295,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
@@ -356,7 +427,7 @@ def _get_or_create_schema_state(txn, database_engine):
             ),
             (current_version,)
         )
-        applied_deltas = [d for d, in txn.fetchall()]
+        applied_deltas = [d for d, in txn]
         return current_version, applied_deltas, upgraded
 
     return None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 4d1590d2b4..9e9d3c2591 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore):
                 self.presence_stream_cache.entity_has_changed,
                 state.user_id, stream_id,
             )
-            self._invalidate_cache_and_stream(
-                txn, self._get_presence_for_user, (state.user_id,)
+            txn.call_after(
+                self._get_presence_for_user.invalidate, (state.user_id,)
             )
 
         # Actually insert new rows
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 26a40905ae..8612bd5ecc 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -13,15 +13,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
+from synapse.storage.roommember import ProfileInfo
+from synapse.api.errors import StoreError
+
 from ._base import SQLBaseStore
 
 
-class ProfileStore(SQLBaseStore):
-    def create_profile(self, user_localpart):
-        return self._simple_insert(
-            table="profiles",
-            values={"user_id": user_localpart},
-            desc="create_profile",
+class ProfileWorkerStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def get_profileinfo(self, user_localpart):
+        try:
+            profile = yield self._simple_select_one(
+                table="profiles",
+                keyvalues={"user_id": user_localpart},
+                retcols=("displayname", "avatar_url"),
+                desc="get_profileinfo",
+            )
+        except StoreError as e:
+            if e.code == 404:
+                # no match
+                defer.returnValue(ProfileInfo(None, None))
+                return
+            else:
+                raise
+
+        defer.returnValue(
+            ProfileInfo(
+                avatar_url=profile['avatar_url'],
+                display_name=profile['displayname'],
+            )
         )
 
     def get_profile_displayname(self, user_localpart):
@@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
-            table="profiles",
-            keyvalues={"user_id": user_localpart},
-            updatevalues={"displayname": new_displayname},
-            desc="set_profile_displayname",
-        )
-
     def get_profile_avatar_url(self, user_localpart):
         return self._simple_select_one_onecol(
             table="profiles",
@@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore):
             desc="get_profile_avatar_url",
         )
 
+    def get_from_remote_profile_cache(self, user_id):
+        return self._simple_select_one(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            retcols=("displayname", "avatar_url",),
+            allow_none=True,
+            desc="get_from_remote_profile_cache",
+        )
+
+
+class ProfileStore(ProfileWorkerStore):
+    def create_profile(self, user_localpart):
+        return self._simple_insert(
+            table="profiles",
+            values={"user_id": user_localpart},
+            desc="create_profile",
+        )
+
+    def set_profile_displayname(self, user_localpart, new_displayname):
+        return self._simple_update_one(
+            table="profiles",
+            keyvalues={"user_id": user_localpart},
+            updatevalues={"displayname": new_displayname},
+            desc="set_profile_displayname",
+        )
+
     def set_profile_avatar_url(self, user_localpart, new_avatar_url):
         return self._simple_update_one(
             table="profiles",
@@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore):
             updatevalues={"avatar_url": new_avatar_url},
             desc="set_profile_avatar_url",
         )
+
+    def add_remote_profile_cache(self, user_id, displayname, avatar_url):
+        """Ensure we are caching the remote user's profiles.
+
+        This should only be called when `is_subscribed_remote_profile_for_user`
+        would return true for the user.
+        """
+        return self._simple_upsert(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="add_remote_profile_cache",
+        )
+
+    def update_remote_profile_cache(self, user_id, displayname, avatar_url):
+        return self._simple_update(
+            table="remote_profile_cache",
+            keyvalues={"user_id": user_id},
+            values={
+                "displayname": displayname,
+                "avatar_url": avatar_url,
+                "last_check": self._clock.time_msec(),
+            },
+            desc="update_remote_profile_cache",
+        )
+
+    @defer.inlineCallbacks
+    def maybe_delete_remote_profile_cache(self, user_id):
+        """Check if we still care about the remote user's profile, and if we
+        don't then remove their profile from the cache
+        """
+        subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+        if not subscribed:
+            yield self._simple_delete(
+                table="remote_profile_cache",
+                keyvalues={"user_id": user_id},
+                desc="delete_remote_profile_cache",
+            )
+
+    def get_remote_profile_cache_entries_that_expire(self, last_checked):
+        """Get all users who haven't been checked since `last_checked`
+        """
+        def _get_remote_profile_cache_entries_that_expire_txn(txn):
+            sql = """
+                SELECT user_id, displayname, avatar_url
+                FROM remote_profile_cache
+                WHERE last_check < ?
+            """
+
+            txn.execute(sql, (last_checked,))
+
+            return self.cursor_to_dict(txn)
+
+        return self.runInteraction(
+            "get_remote_profile_cache_entries_that_expire",
+            _get_remote_profile_cache_entries_that_expire_txn,
+        )
+
+    @defer.inlineCallbacks
+    def is_subscribed_remote_profile_for_user(self, user_id):
+        """Check whether we are interested in a remote user's profile.
+        """
+        res = yield self._simple_select_one_onecol(
+            table="group_users",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            defer.returnValue(True)
+
+        res = yield self._simple_select_one_onecol(
+            table="group_invites",
+            keyvalues={"user_id": user_id},
+            retcol="user_id",
+            allow_none=True,
+            desc="should_update_remote_profile_cache_for_user",
+        )
+
+        if res:
+            defer.returnValue(True)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index cbec255966..04a0b59a39 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,10 +15,17 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
+from synapse.storage.appservice import ApplicationServiceWorkerStore
+from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.roommember import RoomMemberWorkerStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.push.baserules import list_with_base_rules
+from synapse.api.constants import EventTypes
 from twisted.internet import defer
 
+import abc
 import logging
 import simplejson as json
 
@@ -47,8 +55,44 @@ def _load_rules(rawrules, enabled_map):
     return rules
 
 
-class PushRuleStore(SQLBaseStore):
-    @cachedInlineCallbacks()
+class PushRulesWorkerStore(ApplicationServiceWorkerStore,
+                           ReceiptsWorkerStore,
+                           PusherWorkerStore,
+                           RoomMemberWorkerStore,
+                           SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
+    @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
         rows = yield self._simple_select_list(
             table="push_rules",
@@ -72,7 +116,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rules)
 
-    @cachedInlineCallbacks()
+    @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_enabled_for_user(self, user_id):
         results = yield self._simple_select_list(
             table="push_rules_enable",
@@ -88,6 +132,22 @@ class PushRuleStore(SQLBaseStore):
             r['rule_id']: False if r['enabled'] == 0 else True for r in results
         })
 
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
     @cachedList(cached_method_name="get_push_rules_for_user",
                 list_name="user_ids", num_args=1, inlineCallbacks=True)
     def bulk_get_push_rules(self, user_ids):
@@ -184,6 +244,18 @@ class PushRuleStore(SQLBaseStore):
             if uid in local_users_in_room:
                 user_ids.add(uid)
 
+        forgotten = yield self.who_forgot_in_room(
+            event.room_id, on_invalidate=cache_context.invalidate,
+        )
+
+        for row in forgotten:
+            user_id = row["user_id"]
+            event_id = row["event_id"]
+
+            mem_id = current_state_ids.get((EventTypes.Member, user_id), None)
+            if event_id == mem_id:
+                user_ids.discard(user_id)
+
         rules_by_user = yield self.bulk_get_push_rules(
             user_ids, on_invalidate=cache_context.invalidate,
         )
@@ -215,6 +287,8 @@ class PushRuleStore(SQLBaseStore):
             results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
         defer.returnValue(results)
 
+
+class PushRuleStore(PushRulesWorkerStore):
     @defer.inlineCallbacks
     def add_push_rule(
         self, user_id, rule_id, priority_class, conditions, actions,
@@ -513,21 +587,8 @@ class PushRuleStore(SQLBaseStore):
         room stream ordering it corresponds to."""
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
 
 
 class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8cc9f0353b..307660b99a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,7 +28,7 @@ import types
 logger = logging.getLogger(__name__)
 
 
-class PusherStore(SQLBaseStore):
+class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
         for r in rows:
             dataJson = r['data']
@@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
-    def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_current_token()
-
     def get_all_updated_pushers(self, last_id, current_id, limit):
         if last_id == current_id:
             return defer.succeed(([], []))
@@ -135,6 +133,48 @@ class PusherStore(SQLBaseStore):
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
+    def get_all_updated_pushers_rows(self, last_id, current_id, limit):
+        """Get all the pushers that have changed between the given tokens.
+
+        Returns:
+            Deferred(list(tuple)): each tuple consists of:
+                stream_id (str)
+                user_id (str)
+                app_id (str)
+                pushkey (str)
+                was_deleted (bool): whether the pusher was added/updated (False)
+                    or deleted (True)
+        """
+
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_pushers_rows_txn(txn):
+            sql = (
+                "SELECT id, user_name, app_id, pushkey"
+                " FROM pushers"
+                " WHERE ? < id AND id <= ?"
+                " ORDER BY id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            results = [list(row) + [False] for row in txn]
+
+            sql = (
+                "SELECT stream_id, user_id, app_id, pushkey"
+                " FROM deleted_pushers"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+
+            results.extend(list(row) + [True] for row in txn)
+            results.sort()  # Sort so that they're ordered by stream id
+
+            return results
+        return self.runInteraction(
+            "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
+        )
+
     @cachedInlineCallbacks(num_args=1, max_entries=15000)
     def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
@@ -156,56 +196,74 @@ class PusherStore(SQLBaseStore):
 
         defer.returnValue(result)
 
+
+class PusherStore(PusherWorkerStore):
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     @defer.inlineCallbacks
     def add_pusher(self, user_id, access_token, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data, last_stream_ordering,
                    profile_tag=""):
         with self._pushers_id_gen.get_next() as stream_id:
-            def f(txn):
-                newly_inserted = self._simple_upsert_txn(
-                    txn,
-                    "pushers",
-                    {
-                        "app_id": app_id,
-                        "pushkey": pushkey,
-                        "user_name": user_id,
-                    },
-                    {
-                        "access_token": access_token,
-                        "kind": kind,
-                        "app_display_name": app_display_name,
-                        "device_display_name": device_display_name,
-                        "ts": pushkey_ts,
-                        "lang": lang,
-                        "data": encode_canonical_json(data),
-                        "last_stream_ordering": last_stream_ordering,
-                        "profile_tag": profile_tag,
-                        "id": stream_id,
-                    },
-                )
-                if newly_inserted:
-                    # get_if_user_has_pusher only cares if the user has
-                    # at least *one* pusher.
-                    txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+            # no need to lock because `pushers` has a unique key on
+            # (app_id, pushkey, user_name) so _simple_upsert will retry
+            newly_inserted = yield self._simple_upsert(
+                table="pushers",
+                keyvalues={
+                    "app_id": app_id,
+                    "pushkey": pushkey,
+                    "user_name": user_id,
+                },
+                values={
+                    "access_token": access_token,
+                    "kind": kind,
+                    "app_display_name": app_display_name,
+                    "device_display_name": device_display_name,
+                    "ts": pushkey_ts,
+                    "lang": lang,
+                    "data": encode_canonical_json(data),
+                    "last_stream_ordering": last_stream_ordering,
+                    "profile_tag": profile_tag,
+                    "id": stream_id,
+                },
+                desc="add_pusher",
+                lock=False,
+            )
 
-            yield self.runInteraction("add_pusher", f)
+            if newly_inserted:
+                self.runInteraction(
+                    "add_pusher",
+                    self._invalidate_cache_and_stream,
+                    self.get_if_user_has_pusher, (user_id,)
+                )
 
     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
         def delete_pusher_txn(txn, stream_id):
-            txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.get_if_user_has_pusher, (user_id,)
+            )
 
             self._simple_delete_one_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
             )
-            self._simple_upsert_txn(
+
+            # it's possible for us to end up with duplicate rows for
+            # (app_id, pushkey, user_id) at different stream_ids, but that
+            # doesn't really matter.
+            self._simple_insert_txn(
                 txn,
-                "deleted_pushers",
-                {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
-                {"stream_id": stream_id},
+                table="deleted_pushers",
+                values={
+                    "stream_id": stream_id,
+                    "app_id": app_id,
+                    "pushkey": pushkey,
+                    "user_id": user_id,
+                },
             )
 
         with self._pushers_id_gen.get_next() as stream_id:
@@ -268,9 +326,12 @@ class PusherStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def set_throttle_params(self, pusher_id, room_id, params):
+        # no need to lock because `pusher_throttle` has a primary key on
+        # (pusher, room_id) so _simple_upsert will retry
         yield self._simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
-            desc="set_throttle_params"
+            desc="set_throttle_params",
+            lock=False,
         )
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f72d15f5ed..63997ed449 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,46 +15,50 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from twisted.internet import defer
 
+import abc
 import logging
-import ujson as json
+import simplejson as json
 
 
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_receipt_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
+    @abc.abstractmethod
+    def get_max_receipt_stream_id(self):
+        """Get the current max stream ID for receipts stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
         defer.returnValue(set(r['user_id'] for r in receipts))
 
-    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
-                                                    user_id):
-        if receipt_type != "m.read":
-            return
-
-        # Returns an ObservableDeferred
-        res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
-
-        if res and res.called and user_id in res.result:
-            # We'd only be adding to the set, so no point invalidating if the
-            # user is already there
-            return
-
-        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -265,6 +270,59 @@ class ReceiptsStore(SQLBaseStore):
         }
         defer.returnValue(results)
 
+    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_receipts_txn(txn):
+            sql = (
+                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+                " FROM receipts_linearized"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            args = [last_id, current_id]
+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+            txn.execute(sql, args)
+
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_receipts", get_all_updated_receipts_txn
+        )
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns an ObservableDeferred
+        res = self.get_users_with_read_receipts_in_room.cache.get(
+            room_id, None, update_metrics=False,
+        )
+
+        if res:
+            if isinstance(res, defer.Deferred) and res.called:
+                res = res.result
+            if user_id in res:
+                # We'd only be adding to the set, so no point invalidating if the
+                # user is already there
+                return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    def __init__(self, db_conn, hs):
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
@@ -313,10 +371,9 @@ class ReceiptsStore(SQLBaseStore):
         )
 
         txn.execute(sql, (room_id, receipt_type, user_id))
-        results = txn.fetchall()
 
-        if results and topological_ordering:
-            for to, so, _ in results:
+        if topological_ordering:
+            for to, so, _ in txn:
                 if int(to) > topological_ordering:
                     return False
                 elif int(to) == topological_ordering and int(so) >= stream_ordering:
@@ -351,6 +408,7 @@ class ReceiptsStore(SQLBaseStore):
                 room_id=room_id,
                 user_id=user_id,
                 topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
             )
 
         return True
@@ -452,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-    def get_all_updated_receipts(self, last_id, current_id, limit=None):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_receipts_txn(txn):
-            sql = (
-                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
-                " FROM receipts_linearized"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-            )
-            args = [last_id, current_id]
-            if limit is not None:
-                sql += " LIMIT ?"
-                args.append(limit)
-            txn.execute(sql, args)
-
-            return txn.fetchall()
-        return self.runInteraction(
-            "get_all_updated_receipts", get_all_updated_receipts_txn
-        )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26be6060c3..a50717db2d 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,13 +19,75 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError, Codes
 from synapse.storage import background_updates
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
+from six.moves import range
 
-class RegistrationStore(background_updates.BackgroundUpdateStore):
 
-    def __init__(self, hs):
-        super(RegistrationStore, self).__init__(hs)
+class RegistrationWorkerStore(SQLBaseStore):
+    @cached()
+    def get_user_by_id(self, user_id):
+        return self._simple_select_one(
+            table="users",
+            keyvalues={
+                "name": user_id,
+            },
+            retcols=["name", "password_hash", "is_guest"],
+            allow_none=True,
+            desc="get_user_by_id",
+        )
+
+    @cached()
+    def get_user_by_access_token(self, token):
+        """Get a user from the given access token.
+
+        Args:
+            token (str): The access token of a user.
+        Returns:
+            defer.Deferred: None, if the token did not match, otherwise dict
+                including the keys `name`, `is_guest`, `device_id`, `token_id`.
+        """
+        return self.runInteraction(
+            "get_user_by_access_token",
+            self._query_for_auth,
+            token
+        )
+
+    @defer.inlineCallbacks
+    def is_server_admin(self, user):
+        res = yield self._simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user.to_string()},
+            retcol="admin",
+            allow_none=True,
+            desc="is_server_admin",
+        )
+
+        defer.returnValue(res if res else False)
+
+    def _query_for_auth(self, txn, token):
+        sql = (
+            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            " access_tokens.device_id"
+            " FROM users"
+            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
+            " WHERE token = ?"
+        )
+
+        txn.execute(sql, (token,))
+        rows = self.cursor_to_dict(txn)
+        if rows:
+            return rows[0]
+
+        return None
+
+
+class RegistrationStore(RegistrationWorkerStore,
+                        background_updates.BackgroundUpdateStore):
+
+    def __init__(self, db_conn, hs):
+        super(RegistrationStore, self).__init__(db_conn, hs)
 
         self.clock = hs.get_clock()
 
@@ -36,12 +98,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             columns=["user_id", "device_id"],
         )
 
-        self.register_background_index_update(
-            "refresh_tokens_device_index",
-            index_name="refresh_tokens_device_id",
-            table="refresh_tokens",
-            columns=["user_id", "device_id"],
-        )
+        # we no longer use refresh tokens, but it's possible that some people
+        # might have a background update queued to build this index. Just
+        # clear the background update.
+        self.register_noop_background_update("refresh_tokens_device_index")
 
     @defer.inlineCallbacks
     def add_access_token_to_user(self, user_id, token, device_id=None):
@@ -177,9 +237,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             )
 
         if create_profile_with_localpart:
+            # set a default displayname serverside to avoid ugly race
+            # between auto-joins and clients trying to set displaynames
             txn.execute(
-                "INSERT INTO profiles(user_id) VALUES (?)",
-                (create_profile_with_localpart,)
+                "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
+                (create_profile_with_localpart, create_profile_with_localpart)
             )
 
         self._invalidate_cache_and_stream(
@@ -187,18 +249,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    @cached()
-    def get_user_by_id(self, user_id):
-        return self._simple_select_one(
-            table="users",
-            keyvalues={
-                "name": user_id,
-            },
-            retcols=["name", "password_hash", "is_guest"],
-            allow_none=True,
-            desc="get_user_by_id",
-        )
-
     def get_users_by_id_case_insensitive(self, user_id):
         """Gets users that match user_id case insensitively.
         Returns a mapping of user_id -> password_hash.
@@ -209,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 " WHERE lower(name) = lower(?)"
             )
             txn.execute(sql, (user_id,))
-            return dict(txn.fetchall())
+            return dict(txn)
 
         return self.runInteraction("get_users_by_id_case_insensitive", f)
 
@@ -236,12 +286,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             "user_set_password_hash", user_set_password_hash_txn
         )
 
-    @defer.inlineCallbacks
     def user_delete_access_tokens(self, user_id, except_token_id=None,
-                                  device_id=None,
-                                  delete_refresh_tokens=False):
+                                  device_id=None):
         """
-        Invalidate access/refresh tokens belonging to a user
+        Invalidate access tokens belonging to a user
 
         Args:
             user_id (str):  ID of user the tokens belong to
@@ -250,10 +298,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             device_id (str|None):  ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
-            delete_refresh_tokens (bool):  True to delete refresh tokens as
-                well as access tokens.
         Returns:
-            defer.Deferred:
+            defer.Deferred[list[str, int, str|None, int]]: a list of
+                (token, token id, device id) for each of the deleted tokens
         """
         def f(txn):
             keyvalues = {
@@ -262,13 +309,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
             if device_id is not None:
                 keyvalues["device_id"] = device_id
 
-            if delete_refresh_tokens:
-                self._simple_delete_txn(
-                    txn,
-                    table="refresh_tokens",
-                    keyvalues=keyvalues,
-                )
-
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
             values = [v for _, v in items]
@@ -277,14 +317,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values.append(except_token_id)
 
             txn.execute(
-                "SELECT token FROM access_tokens WHERE %s" % where_clause,
+                "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
                 values
             )
-            rows = self.cursor_to_dict(txn)
+            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
 
-            for row in rows:
+            for token, _, _ in tokens_and_devices:
                 self._invalidate_cache_and_stream(
-                    txn, self.get_user_by_access_token, (row["token"],)
+                    txn, self.get_user_by_access_token, (token,)
                 )
 
             txn.execute(
@@ -292,7 +332,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
                 values
             )
 
-        yield self.runInteraction(
+            return tokens_and_devices
+
+        return self.runInteraction(
             "user_delete_access_tokens", f,
         )
 
@@ -312,34 +354,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         return self.runInteraction("delete_access_token", f)
 
-    @cached()
-    def get_user_by_access_token(self, token):
-        """Get a user from the given access token.
-
-        Args:
-            token (str): The access token of a user.
-        Returns:
-            defer.Deferred: None, if the token did not match, otherwise dict
-                including the keys `name`, `is_guest`, `device_id`, `token_id`.
-        """
-        return self.runInteraction(
-            "get_user_by_access_token",
-            self._query_for_auth,
-            token
-        )
-
-    @defer.inlineCallbacks
-    def is_server_admin(self, user):
-        res = yield self._simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user.to_string()},
-            retcol="admin",
-            allow_none=True,
-            desc="is_server_admin",
-        )
-
-        defer.returnValue(res if res else False)
-
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
         res = yield self._simple_select_one_onecol(
@@ -352,22 +366,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
 
         defer.returnValue(res if res else False)
 
-    def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
-            " access_tokens.device_id"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
-
-        txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
-        if rows:
-            return rows[0]
-
-        return None
-
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
         yield self._simple_upsert("user_threepids", {
@@ -438,6 +436,19 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
+    def count_nonbridged_users(self):
+        def _count_users(txn):
+            txn.execute("""
+                SELECT COALESCE(COUNT(*), 0) FROM users
+                WHERE appservice_id IS NULL
+            """)
+            count, = txn.fetchone()
+            return count
+
+        ret = yield self.runInteraction("count_users", _count_users)
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
     def find_next_generated_user_id_localpart(self):
         """
         Gets the localpart of the next generated user ID.
@@ -451,18 +462,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
         """
         def _find_next_generated_user_id(txn):
             txn.execute("SELECT name FROM users")
-            rows = self.cursor_to_dict(txn)
 
             regex = re.compile("^@(\d+):")
 
             found = set()
 
-            for r in rows:
-                user_id = r["name"]
+            for user_id, in txn:
                 match = regex.search(user_id)
                 if match:
                     found.add(int(match.group(1)))
-            for i in xrange(len(found) + 1):
+            for i in range(len(found) + 1):
                 if i not in found:
                     return i
 
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8a2fe2fdf5..ea6a189185 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,14 +16,14 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.util.caches.descriptors import cached
-
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.search import SearchStore
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 import collections
 import logging
-import ujson as json
+import simplejson as json
+import re
 
 logger = logging.getLogger(__name__)
 
@@ -33,8 +33,144 @@ OpsLevel = collections.namedtuple(
     ("ban_level", "kick_level", "redact_level",)
 )
 
+RatelimitOverride = collections.namedtuple(
+    "RatelimitOverride",
+    ("messages_per_second", "burst_count",)
+)
+
 
-class RoomStore(SQLBaseStore):
+class RoomWorkerStore(SQLBaseStore):
+    def get_public_room_ids(self):
+        return self._simple_select_onecol(
+            table="rooms",
+            keyvalues={
+                "is_public": True,
+            },
+            retcol="room_id",
+            desc="get_public_room_ids",
+        )
+
+    @cached(num_args=2, max_entries=100)
+    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+        """Get pulbic rooms for a particular list, or across all lists.
+
+        Args:
+            stream_id (int)
+            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+                means the main list, None means all lsits.
+        """
+        return self.runInteraction(
+            "get_public_room_ids_at_stream_id",
+            self.get_public_room_ids_at_stream_id_txn,
+            stream_id, network_tuple=network_tuple
+        )
+
+    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+                                             network_tuple):
+        return {
+            rm
+            for rm, vis in self.get_published_at_stream_id_txn(
+                txn, stream_id, network_tuple=network_tuple
+            ).items()
+            if vis
+        }
+
+    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+        if network_tuple:
+            # We want to get from a particular list. No aggregation required.
+
+            sql = ("""
+                SELECT room_id, visibility FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT room_id, max(stream_id) AS stream_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ? %s
+                    GROUP BY room_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            if network_tuple.appservice_id is not None:
+                txn.execute(
+                    sql % ("AND appservice_id = ? AND network_id = ?",),
+                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+                )
+            else:
+                txn.execute(
+                    sql % ("AND appservice_id IS NULL",),
+                    (stream_id,)
+                )
+            return dict(txn)
+        else:
+            # We want to get from all lists, so we need to aggregate the results
+
+            logger.info("Executing full list")
+
+            sql = ("""
+                SELECT room_id, visibility
+                FROM public_room_list_stream
+                INNER JOIN (
+                    SELECT
+                        room_id, max(stream_id) AS stream_id, appservice_id,
+                        network_id
+                    FROM public_room_list_stream
+                    WHERE stream_id <= ?
+                    GROUP BY room_id, appservice_id, network_id
+                ) grouped USING (room_id, stream_id)
+            """)
+
+            txn.execute(
+                sql,
+                (stream_id,)
+            )
+
+            results = {}
+            # A room is visible if its visible on any list.
+            for room_id, visibility in txn:
+                results[room_id] = bool(visibility) or results.get(room_id, False)
+
+            return results
+
+    def get_public_room_changes(self, prev_stream_id, new_stream_id,
+                                network_tuple):
+        def get_public_room_changes_txn(txn):
+            then_rooms = self.get_public_room_ids_at_stream_id_txn(
+                txn, prev_stream_id, network_tuple
+            )
+
+            now_rooms_dict = self.get_published_at_stream_id_txn(
+                txn, new_stream_id, network_tuple
+            )
+
+            now_rooms_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if vis
+            )
+            now_rooms_not_visible = set(
+                rm for rm, vis in now_rooms_dict.items() if not vis
+            )
+
+            newly_visible = now_rooms_visible - then_rooms
+            newly_unpublished = now_rooms_not_visible & then_rooms
+
+            return newly_visible, newly_unpublished
+
+        return self.runInteraction(
+            "get_public_room_changes", get_public_room_changes_txn
+        )
+
+    @cached(max_entries=10000)
+    def is_room_blocked(self, room_id):
+        return self._simple_select_one_onecol(
+            table="blocked_rooms",
+            keyvalues={
+                "room_id": room_id,
+            },
+            retcol="1",
+            allow_none=True,
+            desc="is_room_blocked",
+        )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
 
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
@@ -221,16 +357,6 @@ class RoomStore(SQLBaseStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    def get_public_room_ids(self):
-        return self._simple_select_onecol(
-            table="rooms",
-            keyvalues={
-                "is_public": True,
-            },
-            retcol="room_id",
-            desc="get_public_room_ids",
-        )
-
     def get_room_count(self):
         """Retrieve a list of all rooms
         """
@@ -257,8 +383,8 @@ class RoomStore(SQLBaseStore):
                 },
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.topic", event.content["topic"]
+            self.store_event_search_txn(
+                txn, event, "content.topic", event.content["topic"],
             )
 
     def _store_room_name_txn(self, txn, event):
@@ -273,14 +399,14 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
-            self._store_event_search_txn(
-                txn, event, "content.name", event.content["name"]
+            self.store_event_search_txn(
+                txn, event, "content.name", event.content["name"],
             )
 
     def _store_room_message_txn(self, txn, event):
         if hasattr(event, "content") and "body" in event.content:
-            self._store_event_search_txn(
-                txn, event, "content.body", event.content["body"]
+            self.store_event_search_txn(
+                txn, event, "content.body", event.content["body"],
             )
 
     def _store_history_visibility_txn(self, txn, event):
@@ -302,31 +428,6 @@ class RoomStore(SQLBaseStore):
                 event.content[key]
             ))
 
-    def _store_event_search_txn(self, txn, event, key, value):
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
-            txn.execute(
-                sql,
-                (
-                    event.event_id, event.room_id, key, value,
-                    event.internal_metadata.stream_ordering,
-                    event.origin_server_ts,
-                )
-            )
-        elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            txn.execute(sql, (event.event_id, event.room_id, key, value,))
-        else:
-            # This should be unreachable.
-            raise Exception("Unrecognized database engine")
-
     def add_event_report(self, room_id, event_id, user_id, reason, content,
                          received_ts):
         next_id = self._event_reports_id_gen.get_next()
@@ -347,129 +448,180 @@ class RoomStore(SQLBaseStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    @cached(num_args=2, max_entries=100)
-    def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
-        """Get pulbic rooms for a particular list, or across all lists.
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = ("""
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """)
+
+            txn.execute(sql, (prev_id, current_id, limit,))
+            return txn.fetchall()
+
+        if prev_id == current_id:
+            return defer.succeed([])
 
-        Args:
-            stream_id (int)
-            network_tuple (ThirdPartyInstanceID): The list to use (None, None)
-                means the main list, None means all lsits.
-        """
         return self.runInteraction(
-            "get_public_room_ids_at_stream_id",
-            self.get_public_room_ids_at_stream_id_txn,
-            stream_id, network_tuple=network_tuple
+            "get_all_new_public_rooms", get_all_new_public_rooms
         )
 
-    def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
-                                             network_tuple):
-        return {
-            rm
-            for rm, vis in self.get_published_at_stream_id_txn(
-                txn, stream_id, network_tuple=network_tuple
-            ).items()
-            if vis
-        }
+    @cachedInlineCallbacks(max_entries=10000)
+    def get_ratelimit_for_user(self, user_id):
+        """Check if there are any overrides for ratelimiting for the given
+        user
 
-    def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
-        if network_tuple:
-            # We want to get from a particular list. No aggregation required.
+        Args:
+            user_id (str)
 
-            sql = ("""
-                SELECT room_id, visibility FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT room_id, max(stream_id) AS stream_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ? %s
-                    GROUP BY room_id
-                ) grouped USING (room_id, stream_id)
-            """)
+        Returns:
+            RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimitng has been
+            disabled for that user entirely.
+        """
+        row = yield self._simple_select_one(
+            table="ratelimit_override",
+            keyvalues={"user_id": user_id},
+            retcols=("messages_per_second", "burst_count"),
+            allow_none=True,
+            desc="get_ratelimit_for_user",
+        )
 
-            if network_tuple.appservice_id is not None:
-                txn.execute(
-                    sql % ("AND appservice_id = ? AND network_id = ?",),
-                    (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
-                )
-            else:
-                txn.execute(
-                    sql % ("AND appservice_id IS NULL",),
-                    (stream_id,)
-                )
-            return dict(txn.fetchall())
+        if row:
+            defer.returnValue(RatelimitOverride(
+                messages_per_second=row["messages_per_second"],
+                burst_count=row["burst_count"],
+            ))
         else:
-            # We want to get from all lists, so we need to aggregate the results
+            defer.returnValue(None)
 
-            logger.info("Executing full list")
-
-            sql = ("""
-                SELECT room_id, visibility
-                FROM public_room_list_stream
-                INNER JOIN (
-                    SELECT
-                        room_id, max(stream_id) AS stream_id, appservice_id,
-                        network_id
-                    FROM public_room_list_stream
-                    WHERE stream_id <= ?
-                    GROUP BY room_id, appservice_id, network_id
-                ) grouped USING (room_id, stream_id)
-            """)
-
-            txn.execute(
-                sql,
-                (stream_id,)
-            )
-
-            results = {}
-            # A room is visible if its visible on any list.
-            for room_id, visibility in txn.fetchall():
-                results[room_id] = bool(visibility) or results.get(room_id, False)
-
-            return results
+    @defer.inlineCallbacks
+    def block_room(self, room_id, user_id):
+        yield self._simple_insert(
+            table="blocked_rooms",
+            values={
+                "room_id": room_id,
+                "user_id": user_id,
+            },
+            desc="block_room",
+        )
+        yield self.runInteraction(
+            "block_room_invalidation",
+            self._invalidate_cache_and_stream,
+            self.is_room_blocked, (room_id,),
+        )
 
-    def get_public_room_changes(self, prev_stream_id, new_stream_id,
-                                network_tuple):
-        def get_public_room_changes_txn(txn):
-            then_rooms = self.get_public_room_ids_at_stream_id_txn(
-                txn, prev_stream_id, network_tuple
-            )
+    def get_media_mxcs_in_room(self, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
 
-            now_rooms_dict = self.get_published_at_stream_id_txn(
-                txn, new_stream_id, network_tuple
-            )
+        Args:
+            room_id (str)
 
-            now_rooms_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if vis
-            )
-            now_rooms_not_visible = set(
-                rm for rm, vis in now_rooms_dict.items() if not vis
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        def _get_media_mxcs_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            local_media_mxcs = []
+            remote_media_mxcs = []
+
+            # Convert the IDs to MXC URIs
+            for media_id in local_mxcs:
+                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+            for hostname, media_id in remote_mxcs:
+                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+            return local_media_mxcs, remote_media_mxcs
+        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
+    def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+        """For a room loops through all events with media and quarantines
+        the associated media
+        """
+        def _quarantine_media_in_room_txn(txn):
+            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+            total_media_quarantined = 0
+
+            # Now update all the tables to set the quarantined_by flag
+
+            txn.executemany("""
+                UPDATE local_media_repository
+                SET quarantined_by = ?
+                WHERE media_id = ?
+            """, ((quarantined_by, media_id) for media_id in local_mxcs))
+
+            txn.executemany(
+                """
+                    UPDATE remote_media_cache
+                    SET quarantined_by = ?
+                    WHERE media_origin = ? AND media_id = ?
+                """,
+                (
+                    (quarantined_by, origin, media_id)
+                    for origin, media_id in remote_mxcs
+                )
             )
 
-            newly_visible = now_rooms_visible - then_rooms
-            newly_unpublished = now_rooms_not_visible & then_rooms
+            total_media_quarantined += len(local_mxcs)
+            total_media_quarantined += len(remote_mxcs)
 
-            return newly_visible, newly_unpublished
+            return total_media_quarantined
 
         return self.runInteraction(
-            "get_public_room_changes", get_public_room_changes_txn
+            "quarantine_media_in_room",
+            _quarantine_media_in_room_txn,
         )
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = ("""
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """)
-
-            txn.execute(sql, (prev_id, current_id, limit,))
-            return txn.fetchall()
+    def _get_media_mxcs_in_room_txn(self, txn, room_id):
+        """Retrieves all the local and remote media MXC URIs in a given room
 
-        if prev_id == current_id:
-            return defer.succeed([])
+        Args:
+            txn (cursor)
+            room_id (str)
 
-        return self.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
+        Returns:
+            The local and remote media as a lists of tuples where the key is
+            the hostname and the value is the media ID.
+        """
+        mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+        next_token = self.get_current_events_token() + 1
+        local_media_mxcs = []
+        remote_media_mxcs = []
+
+        while next_token:
+            sql = """
+                SELECT stream_ordering, json FROM events
+                JOIN event_json USING (room_id, event_id)
+                WHERE room_id = ?
+                    AND stream_ordering < ?
+                    AND contains_url = ? AND outlier = ?
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+            txn.execute(sql, (room_id, next_token, True, False, 100))
+
+            next_token = None
+            for stream_ordering, content_json in txn:
+                next_token = stream_ordering
+                event_json = json.loads(content_json)
+                content = event_json["content"]
+                content_url = content.get("url")
+                thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+                for url in (content_url, thumbnail_url):
+                    if not url:
+                        continue
+                    matches = mxc_re.match(url)
+                    if matches:
+                        hostname = matches.group(1)
+                        media_id = matches.group(2)
+                        if hostname == self.hs.hostname:
+                            local_media_mxcs.append(media_id)
+                        else:
+                            remote_media_mxcs.append((hostname, media_id))
+
+        return local_media_mxcs, remote_media_mxcs
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 545d3d3a99..6a861943a2 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,14 +18,17 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+from synapse.util.async import Linearizer
+from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.stringutils import to_ascii
 
 from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -34,112 +38,47 @@ RoomsForUser = namedtuple(
     ("room_id", "sender", "membership", "event_id", "stream_ordering")
 )
 
+GetRoomsForUserWithStreamOrdering = namedtuple(
+    "_GetRoomsForUserWithStreamOrdering",
+    ("room_id", "stream_ordering",)
+)
 
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-
-
-class RoomMemberStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(RoomMemberStore, self).__init__(hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
 
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ]
-        )
+# We store this using a namedtuple so that we save about 3x space over using a
+# dict.
+ProfileInfo = namedtuple(
+    "ProfileInfo", ("avatar_url", "display_name")
+)
 
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
 
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened.
-            # The only current event that can also be an outlier is if its an
-            # invite that has come in across federation.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_invite_from_remote()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        }
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
 
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
+class RoomMemberWorkerStore(EventsWorkerStore):
+    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+    def get_hosts_in_room(self, room_id, cache_context):
+        """Returns the set of all hosts currently in the room
+        """
+        user_ids = yield self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate,
         )
+        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+        defer.returnValue(hosts)
 
-        def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
-    @cached(max_entries=500000, iterable=True)
+    @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
-
-            rows = self._get_members_rows_txn(
-                txn,
-                room_id=room_id,
-                membership=Membership.JOIN,
+            sql = (
+                "SELECT m.user_id FROM room_memberships as m"
+                " INNER JOIN current_state_events as c"
+                " ON m.event_id = c.event_id "
+                " AND m.room_id = c.room_id "
+                " AND m.user_id = c.state_key"
+                " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
             )
 
-            return [r["user_id"] for r in rows]
+            txn.execute(sql, (room_id, Membership.JOIN,))
+            return [to_ascii(r[0]) for r in txn]
         return self.runInteraction("get_users_in_room", f)
 
     @cached()
@@ -246,57 +185,382 @@ class RoomMemberStore(SQLBaseStore):
 
         return results
 
-    def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
-        where_clause = "c.room_id = ?"
-        where_values = [room_id]
-
-        if membership:
-            where_clause += " AND m.membership = ?"
-            where_values.append(membership)
+    @cachedInlineCallbacks(max_entries=500000, iterable=True)
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
+        """Returns a set of room_ids the user is currently joined to
 
-        if user_id:
-            where_clause += " AND m.user_id = ?"
-            where_values.append(user_id)
+        Args:
+            user_id (str)
 
-        sql = (
-            "SELECT m.* FROM room_memberships as m"
-            " INNER JOIN current_state_events as c"
-            " ON m.event_id = c.event_id "
-            " AND m.room_id = c.room_id "
-            " AND m.user_id = c.state_key"
-            " WHERE c.type = 'm.room.member' AND %(where)s"
-        ) % {
-            "where": where_clause,
-        }
-
-        txn.execute(sql, where_values)
-        rows = self.cursor_to_dict(txn)
-
-        return rows
-
-    @cached(max_entries=500000, iterable=True)
-    def get_rooms_for_user(self, user_id):
-        return self.get_rooms_for_user_where_membership_is(
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
+        """
+        rooms = yield self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
+        defer.returnValue(frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
+        ))
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate,
+        )
+        defer.returnValue(frozenset(r.room_id for r in rooms))
 
     @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
     def get_users_who_share_room_with_user(self, user_id, cache_context):
         """Returns the set of users who share a room with `user_id`
         """
-        rooms = yield self.get_rooms_for_user(
+        room_ids = yield self.get_rooms_for_user(
             user_id, on_invalidate=cache_context.invalidate,
         )
 
         user_who_share_room = set()
-        for room in rooms:
+        for room_id in room_ids:
             user_ids = yield self.get_users_in_room(
-                room.room_id, on_invalidate=cache_context.invalidate,
+                room_id, on_invalidate=cache_context.invalidate,
             )
             user_who_share_room.update(user_ids)
 
         defer.returnValue(user_who_share_room)
 
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            event.room_id, state_group, context.current_state_ids,
+            event=event,
+            context=context,
+        )
+
+    def get_joined_users_from_state(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_users_from_context(
+            room_id, state_group, state_entry.state, context=state_entry,
+        )
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
+                           max_entries=100000)
+    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
+                                       cache_context, event=None, context=None):
+        # We don't use `state_group`, it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two current_states
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        users_in_room = {}
+        member_event_ids = [
+            e_id
+            for key, e_id in current_state_ids.iteritems()
+            if key[0] == EventTypes.Member
+        ]
+
+        if context is not None:
+            # If we have a context with a delta from a previous state group,
+            # check if we also have the result from the previous group in cache.
+            # If we do then we can reuse that result and simply update it with
+            # any membership changes in `delta_ids`
+            if context.prev_group and context.delta_ids:
+                prev_res = self._get_joined_users_from_context.cache.get(
+                    (room_id, context.prev_group), None
+                )
+                if prev_res and isinstance(prev_res, dict):
+                    users_in_room = dict(prev_res)
+                    member_event_ids = [
+                        e_id
+                        for key, e_id in context.delta_ids.iteritems()
+                        if key[0] == EventTypes.Member
+                    ]
+                    for etype, state_key in context.delta_ids:
+                        users_in_room.pop(state_key, None)
+
+        # We check if we have any of the member event ids in the event cache
+        # before we ask the DB
+
+        # We don't update the event cache hit ratio as it completely throws off
+        # the hit ratio counts. After all, we don't populate the cache if we
+        # miss it here
+        event_map = self._get_events_from_cache(
+            member_event_ids,
+            allow_rejected=False,
+            update_metrics=False,
+        )
+
+        missing_member_event_ids = []
+        for event_id in member_event_ids:
+            ev_entry = event_map.get(event_id)
+            if ev_entry:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(
+                            ev_entry.event.content.get("displayname", None)
+                        ),
+                        avatar_url=to_ascii(
+                            ev_entry.event.content.get("avatar_url", None)
+                        ),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            rows = yield self._simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=missing_member_event_ids,
+                retcols=('user_id', 'display_name', 'avatar_url',),
+                keyvalues={
+                    "membership": Membership.JOIN,
+                },
+                batch_size=500,
+                desc="_get_joined_users_from_context",
+            )
+
+            users_in_room.update({
+                to_ascii(row["user_id"]): ProfileInfo(
+                    avatar_url=to_ascii(row["avatar_url"]),
+                    display_name=to_ascii(row["display_name"]),
+                )
+                for row in rows
+            })
+
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(event.content.get("displayname", None)),
+                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
+                    )
+
+        defer.returnValue(users_in_room)
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def is_host_joined(self, room_id, host):
+        if '%' in host or '_' in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT state_key FROM current_state_events AS c
+            INNER JOIN room_memberships USING (event_id)
+            WHERE membership = 'join'
+                AND type = 'm.room.member'
+                AND c.room_id = ?
+                AND state_key LIKE ?
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            defer.returnValue(False)
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        defer.returnValue(True)
+
+    @cachedInlineCallbacks()
+    def was_host_joined(self, room_id, host):
+        """Check whether the server is or ever was in the room.
+
+        Args:
+            room_id (str)
+            host (str)
+
+        Returns:
+            Deferred: Resolves to True if the host is/was in the room, otherwise
+            False.
+        """
+        if '%' in host or '_' in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT user_id FROM room_memberships
+            WHERE room_id = ?
+                AND user_id LIKE ?
+                AND membership = 'join'
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            defer.returnValue(False)
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        defer.returnValue(True)
+
+    def get_joined_hosts(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts(
+            room_id, state_group, state_entry.state, state_entry=state_entry,
+        )
+
+    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+    # @defer.inlineCallbacks
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        cache = self._get_joined_hosts_cache(room_id)
+        joined_hosts = yield cache.get_destinations(state_entry)
+
+        defer.returnValue(joined_hosts)
+
+    @cached(max_entries=10000, iterable=True)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ]
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key, event.internal_metadata.stream_ordering
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened.
+            # The only current event that can also be an outlier is if its an
+            # invite that has come in across federation.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_invite_from_remote()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        }
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(sql, (
+                        event.internal_metadata.stream_ordering,
+                        event.event_id,
+                        event.room_id,
+                        event.state_key,
+                    ))
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (
+                stream_ordering,
+                True,
+                room_id,
+                user_id,
+            ))
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
@@ -368,124 +632,6 @@ class RoomMemberStore(SQLBaseStore):
         forgot = yield self.runInteraction("did_forget_membership_at", f)
         defer.returnValue(forgot == 1)
 
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
-
-    def get_joined_users_from_context(self, event, context):
-        state_group = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        return self._get_joined_users_from_context(
-            event.room_id, state_group, context.current_state_ids, event=event,
-        )
-
-    def get_joined_users_from_state(self, room_id, state_group, state_ids):
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        return self._get_joined_users_from_context(
-            room_id, state_group, state_ids,
-        )
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
-                           max_entries=100000)
-    def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
-                                       cache_context, event=None):
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        member_event_ids = [
-            e_id
-            for key, e_id in current_state_ids.iteritems()
-            if key[0] == EventTypes.Member
-        ]
-
-        rows = yield self._simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=['user_id', 'display_name', 'avatar_url'],
-            keyvalues={
-                "membership": Membership.JOIN,
-            },
-            batch_size=500,
-            desc="_get_joined_users_from_context",
-        )
-
-        users_in_room = {
-            row["user_id"]: {
-                "display_name": row["display_name"],
-                "avatar_url": row["avatar_url"],
-            }
-            for row in rows
-        }
-
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = {
-                        "display_name": event.content.get("displayname", None),
-                        "avatar_url": event.content.get("avatar_url", None),
-                    }
-
-        defer.returnValue(users_in_room)
-
-    def is_host_joined(self, room_id, host, state_group, state_ids):
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        return self._is_host_joined(
-            room_id, host, state_group, state_ids
-        )
-
-    @cachedInlineCallbacks(num_args=3)
-    def _is_host_joined(self, room_id, host, state_group, current_state_ids):
-        # We don't use `state_group`, its there so that we can cache based
-        # on it. However, its important that its never None, since two current_state's
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        for (etype, state_key), event_id in current_state_ids.items():
-            if etype == EventTypes.Member:
-                try:
-                    if get_domain_from_id(state_key) != host:
-                        continue
-                except:
-                    logger.warn("state_key not user_id: %s", state_key)
-                    continue
-
-                event = yield self.get_event(event_id, allow_none=True)
-                if event and event.content["membership"] == Membership.JOIN:
-                    defer.returnValue(True)
-
-        defer.returnValue(False)
-
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
@@ -499,8 +645,9 @@ class RoomMemberStore(SQLBaseStore):
 
         def add_membership_profile_txn(txn):
             sql = ("""
-                SELECT stream_ordering, event_id, events.room_id, content
+                SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
+                INNER JOIN event_json USING (event_id)
                 INNER JOIN room_memberships USING (event_id)
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                 AND type = 'm.room.member'
@@ -521,8 +668,9 @@ class RoomMemberStore(SQLBaseStore):
                 event_id = row["event_id"]
                 room_id = row["room_id"]
                 try:
-                    content = json.loads(row["content"])
-                except:
+                    event_json = json.loads(row["json"])
+                    content = event_json['content']
+                except Exception:
                     continue
 
                 display_name = content.get("displayname", None)
@@ -560,3 +708,71 @@ class RoomMemberStore(SQLBaseStore):
             yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
 
         defer.returnValue(result)
+
+
+class _JoinedHostsCache(object):
+    """Cache for joined hosts in a room that is optimised to handle updates
+    via state deltas.
+    """
+
+    def __init__(self, store, room_id):
+        self.store = store
+        self.room_id = room_id
+
+        self.hosts_to_joined_users = {}
+
+        self.state_group = object()
+
+        self.linearizer = Linearizer("_JoinedHostsCache")
+
+        self._len = 0
+
+    @defer.inlineCallbacks
+    def get_destinations(self, state_entry):
+        """Get set of destinations for a state entry
+
+        Args:
+            state_entry(synapse.state._StateCacheEntry)
+        """
+        if state_entry.state_group == self.state_group:
+            defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+        with (yield self.linearizer.queue(())):
+            if state_entry.state_group == self.state_group:
+                pass
+            elif state_entry.prev_group == self.state_group:
+                for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+                    event = yield self.store.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            self.hosts_to_joined_users.pop(host, None)
+            else:
+                joined_users = yield self.store.get_joined_users_from_state(
+                    self.room_id, state_entry,
+                )
+
+                self.hosts_to_joined_users = {}
+                for user_id in joined_users:
+                    host = intern_string(get_domain_from_id(user_id))
+                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                self.state_group = state_entry.state_group
+            else:
+                self.state_group = object()
+            self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+        defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+    def __len__(self):
+        return self._len
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 
+import simplejson as json
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..e7351c3ae6 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -17,7 +17,7 @@ import logging
 from synapse.storage.prepare_database import get_statements
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..6df57b5206 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -16,7 +16,7 @@ import logging
 
 from synapse.storage.prepare_database import get_statements
 
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index 5b7d8d1ab5..85bd1a2006 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -14,6 +14,8 @@
 import logging
 from synapse.config.appservice import load_appservices
 
+from six.moves import range
+
 
 logger = logging.getLogger(__name__)
 
@@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs):
     # NULL indicates user was not registered by an appservice.
     try:
         cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
-    except:
+    except Exception:
         # Maybe we already added the column? Hope so...
         pass
 
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
 
     for as_id, user_ids in owned.items():
         n = 100
-        user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+        user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
                 database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..fe6b7d196d 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.prepare_database import get_statements
 
 import logging
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "rows_inserted": 0,
             "have_added_indexes": False,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..1e002f9db2 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -15,7 +15,7 @@
 from synapse.storage.prepare_database import get_statements
 
 import logging
-import ujson
+import simplejson
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs):
             "max_stream_id_exclusive": max_stream_id + 1,
             "rows_inserted": 0,
         }
-        progress_json = ujson.dumps(progress)
+        progress_json = simplejson.dumps(progress)
 
         sql = (
             "INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py
index 784f3b348f..20ad8bd5a6 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/schema/delta/37/remove_auth_idx.py
@@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref;
 -- and is used incredibly rarely.
 DROP INDEX IF EXISTS events_order_topo_stream_room;
 
+-- an equivalent index to this actually gets re-created in delta 41, because it
+-- turned out that deleting it wasn't a great plan :/. In any case, let's
+-- delete it here, and delta 41 will create a new one with an added UNIQUE
+-- constraint
 DROP INDEX IF EXISTS event_search_ev_idx;
 """
 
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
index f090a7b75a..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql
@@ -13,5 +13,7 @@
  * limitations under the License.
  */
 
- INSERT into background_updates (update_name, progress_json)
-     VALUES ('event_search_postgres_gist', '{}');
+-- We no longer do this given we back it out again in schema 47
+
+-- INSERT into background_updates (update_name, progress_json)
+--     VALUES ('event_search_postgres_gist', '{}');
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql
new file mode 100644
index 0000000000..3918f0b794
--- /dev/null
+++ b/synapse/storage/schema/delta/40/event_push_summary.sql
@@ -0,0 +1,37 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Aggregate of old notification counts that have been deleted out of the
+-- main event_push_actions table. This count does not include those that were
+-- highlights, as they remain in the event_push_actions table.
+CREATE TABLE event_push_summary (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    notif_count BIGINT NOT NULL,
+    stream_ordering BIGINT NOT NULL
+);
+
+CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id);
+
+
+-- The stream ordering up to which we have aggregated the event_push_actions
+-- table into event_push_summary
+CREATE TABLE event_push_summary_stream_ordering (
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_ordering BIGINT NOT NULL,
+    CHECK (Lock='X')
+);
+
+INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql
new file mode 100644
index 0000000000..054a223f14
--- /dev/null
+++ b/synapse/storage/schema/delta/40/pushers.sql
@@ -0,0 +1,39 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS pushers2 (
+    id BIGINT PRIMARY KEY,
+    user_name TEXT NOT NULL,
+    access_token BIGINT DEFAULT NULL,
+    profile_tag TEXT NOT NULL,
+    kind TEXT NOT NULL,
+    app_id TEXT NOT NULL,
+    app_display_name TEXT NOT NULL,
+    device_display_name TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    ts BIGINT NOT NULL,
+    lang TEXT,
+    data TEXT,
+    last_stream_ordering INTEGER,
+    last_success BIGINT,
+    failing_since BIGINT,
+    UNIQUE (app_id, pushkey, user_name)
+);
+
+INSERT INTO pushers2 SELECT * FROM PUSHERS;
+
+DROP TABLE PUSHERS;
+
+ALTER TABLE pushers2 RENAME TO pushers;
diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
index 34db0cf12b..b7bee8b692 100644
--- a/synapse/storage/schema/delta/23/refresh_tokens.sql
+++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql
@@ -1,4 +1,4 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,9 +13,5 @@
  * limitations under the License.
  */
 
-CREATE TABLE IF NOT EXISTS refresh_tokens(
-    id INTEGER PRIMARY KEY,
-    token TEXT NOT NULL,
-    user_id TEXT NOT NULL,
-    UNIQUE (token)
-);
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('device_lists_stream_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql
new file mode 100644
index 0000000000..62f0b9892b
--- /dev/null
+++ b/synapse/storage/schema/delta/41/device_outbound_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id);
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
new file mode 100644
index 0000000000..5d9cfecf36
--- /dev/null
+++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('event_search_event_id_idx', '{}');
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql
new file mode 100644
index 0000000000..a194bf0238
--- /dev/null
+++ b/synapse/storage/schema/delta/41/ratelimit.sql
@@ -0,0 +1,22 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE ratelimit_override (
+    user_id TEXT NOT NULL,
+    messages_per_second BIGINT,
+    burst_count BIGINT
+);
+
+CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id);
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql
new file mode 100644
index 0000000000..d28851aff8
--- /dev/null
+++ b/synapse/storage/schema/delta/42/current_state_delta.sql
@@ -0,0 +1,26 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE current_state_delta_stream (
+    stream_id BIGINT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    event_id TEXT,  -- Is null if the key was removed
+    prev_event_id TEXT  -- Is null if the key was added
+);
+
+CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id);
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql
new file mode 100644
index 0000000000..9ab8c14fa3
--- /dev/null
+++ b/synapse/storage/schema/delta/42/device_list_last_id.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Table of last stream_id that we sent to destination for user_id. This is
+-- used to fill out the `prev_id` fields of outbound device list updates.
+CREATE TABLE device_lists_outbound_last_success (
+    destination TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+INSERT INTO device_lists_outbound_last_success
+    SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id
+        FROM device_lists_outbound_pokes
+        WHERE sent = (1 = 1)  -- sqlite doesn't have inbuilt boolean values
+        GROUP BY destination, user_id;
+
+CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success(
+    destination, user_id, stream_id
+);
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql
index bb225dafbf..b8821ac759 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql
+++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,4 +14,4 @@
  */
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
-  ('refresh_tokens_device_index', '{}');
+  ('event_auth_state_only', '{}');
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py
new file mode 100644
index 0000000000..ea6a18196d
--- /dev/null
+++ b/synapse/storage/schema/delta/42/user_dir.py
@@ -0,0 +1,84 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+
+logger = logging.getLogger(__name__)
+
+
+BOTH_TABLES = """
+CREATE TABLE user_directory_stream_pos (
+    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE,  -- Makes sure this table only has one row.
+    stream_id BIGINT,
+    CHECK (Lock='X')
+);
+
+INSERT INTO user_directory_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_directory (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,  -- A room_id that we know the user is joined to
+    display_name TEXT,
+    avatar_url TEXT
+);
+
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
+
+CREATE TABLE users_in_pubic_room (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL  -- A room_id that we know is public
+);
+
+CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id);
+CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id);
+"""
+
+
+POSTGRES_TABLE = """
+CREATE TABLE user_directory_search (
+    user_id TEXT NOT NULL,
+    vector tsvector
+);
+
+CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector);
+CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id);
+"""
+
+
+SQLITE_TABLE = """
+CREATE VIRTUAL TABLE user_directory_search
+    USING fts4 ( user_id, value );
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    for statement in get_statements(BOTH_TABLES.splitlines()):
+        cur.execute(statement)
+
+    if isinstance(database_engine, PostgresEngine):
+        for statement in get_statements(POSTGRES_TABLE.splitlines()):
+            cur.execute(statement)
+    elif isinstance(database_engine, Sqlite3Engine):
+        for statement in get_statements(SQLITE_TABLE.splitlines()):
+            cur.execute(statement)
+    else:
+        raise Exception("Unrecognized database engine")
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql
new file mode 100644
index 0000000000..0e3cd143ff
--- /dev/null
+++ b/synapse/storage/schema/delta/43/blocked_rooms.sql
@@ -0,0 +1,21 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE blocked_rooms (
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL  -- Admin who blocked the room
+);
+
+CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id);
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql
new file mode 100644
index 0000000000..630907ec4f
--- /dev/null
+++ b/synapse/storage/schema/delta/43/quarantine_media.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT;
+ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT;
diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/43/url_cache.sql
index 290bd6da86..45ebe020da 100644
--- a/synapse/storage/schema/delta/33/refreshtoken_device.sql
+++ b/synapse/storage/schema/delta/43/url_cache.sql
@@ -1,4 +1,4 @@
-/* Copyright 2016 OpenMarket Ltd
+/* Copyright 2017 Vector Creations Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@
  * limitations under the License.
  */
 
-ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;
+ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT;
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql
new file mode 100644
index 0000000000..ee7062abe4
--- /dev/null
+++ b/synapse/storage/schema/delta/43/user_share.sql
@@ -0,0 +1,33 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Table keeping track of who shares a room with who. We only keep track
+-- of this for local users, so `user_id` is local users only (but we do keep track
+-- of which remote users share a room)
+CREATE TABLE users_who_share_rooms (
+    user_id TEXT NOT NULL,
+    other_user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    share_private BOOLEAN NOT NULL  -- is the shared room private? i.e. they share a private room
+);
+
+
+CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id);
+CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id);
+CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id);
+
+
+-- Make sure that we populate the table initially
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..b12f9b2ebf
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,41 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
+-- removed and replaced with 46/local_media_repository_url_idx.sql.
+--
+-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
+-- indices on expressions until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+    url TEXT,
+    response_code INTEGER,
+    etag TEXT,
+    expires_ts BIGINT,
+    og TEXT,
+    media_id TEXT,
+    download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+    SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql
new file mode 100644
index 0000000000..b2333848a0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/group_server.sql
@@ -0,0 +1,167 @@
+/* Copyright 2017 Vector Creations Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups (
+    group_id TEXT NOT NULL,
+    name TEXT,  -- the display name of the room
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT
+);
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
+
+
+-- list of users the group server thinks are joined
+CREATE TABLE group_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the users membership can be seen by everyone
+);
+
+
+CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
+CREATE INDEX groups_users_u_idx ON group_users(user_id);
+
+-- list of users the group server thinks are invited
+CREATE TABLE group_invites (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
+CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
+
+
+CREATE TABLE group_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the room can be seen by everyone
+);
+
+CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
+CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
+
+
+-- Rooms to include in the summary
+CREATE TABLE group_summary_rooms (
+    group_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    room_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
+    UNIQUE (group_id, category_id, room_id, room_order),
+    CHECK (room_order > 0)
+);
+
+CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
+
+
+-- Categories to include in the summary
+CREATE TABLE group_summary_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    cat_order BIGINT NOT NULL,
+    UNIQUE (group_id, category_id, cat_order),
+    CHECK (cat_order > 0)
+);
+
+-- The categories in the group
+CREATE TABLE group_room_categories (
+    group_id TEXT NOT NULL,
+    category_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
+    UNIQUE (group_id, category_id)
+);
+
+-- The users to include in the group summary
+CREATE TABLE group_summary_users (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    user_order BIGINT NOT NULL,
+    is_public BOOLEAN NOT NULL  -- whether the user should be show to everyone
+);
+
+CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
+
+-- The roles to include in the group summary
+CREATE TABLE group_summary_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    role_order BIGINT NOT NULL,
+    UNIQUE (group_id, role_id, role_order),
+    CHECK (role_order > 0)
+);
+
+
+-- The roles in a groups
+CREATE TABLE group_roles (
+    group_id TEXT NOT NULL,
+    role_id TEXT NOT NULL,
+    profile TEXT NOT NULL,
+    is_public BOOLEAN NOT NULL,  -- whether the role should be show to everyone
+    UNIQUE (group_id, role_id)
+);
+
+
+-- List of  attestations we've given out and need to renew
+CREATE TABLE group_attestations_renewals (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL
+);
+
+CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
+CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
+CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
+
+
+-- List of attestations we've received from remotes and are interested in.
+CREATE TABLE group_attestations_remote (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    valid_until_ms BIGINT NOT NULL,
+    attestation_json TEXT NOT NULL
+);
+
+CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
+CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
+CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
+
+
+-- The group membership for the HS's users
+CREATE TABLE local_group_membership (
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    is_admin BOOLEAN NOT NULL,
+    membership TEXT NOT NULL,
+    is_publicised BOOLEAN NOT NULL,  -- if the user is publicising their membership
+    content TEXT NOT NULL
+);
+
+CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
+CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
+
+
+CREATE TABLE local_group_updates (
+    stream_id BIGINT NOT NULL,
+    group_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    content TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql
new file mode 100644
index 0000000000..e5ddc84df0
--- /dev/null
+++ b/synapse/storage/schema/delta/45/profile_cache.sql
@@ -0,0 +1,28 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A subset of remote users whose profiles we have cached.
+-- Whether a user is in this table or not is defined by the storage function
+-- `is_subscribed_remote_profile_for_user`
+CREATE TABLE remote_profile_cache (
+    user_id TEXT NOT NULL,
+    displayname TEXT,
+    avatar_url TEXT,
+    last_check BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
+CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
new file mode 100644
index 0000000000..68c48a89a9
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
@@ -0,0 +1,17 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* we no longer use (or create) the refresh_tokens table */
+DROP TABLE IF EXISTS refresh_tokens;
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
new file mode 100644
index 0000000000..bb307889c1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop the unique constraint on deleted_pushers so that we can just insert
+-- into it rather than upserting.
+
+CREATE TABLE deleted_pushers2 (
+    stream_id BIGINT NOT NULL,
+    app_id TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    user_id TEXT NOT NULL
+);
+
+INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
+    SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
+
+DROP TABLE deleted_pushers;
+ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
+
+-- create the index after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
+
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql
new file mode 100644
index 0000000000..097679bc9a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/group_server.sql
@@ -0,0 +1,32 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE groups_new (
+    group_id TEXT NOT NULL,
+    name TEXT,  -- the display name of the room
+    avatar_url TEXT,
+    short_description TEXT,
+    long_description TEXT,
+    is_public BOOL NOT NULL -- whether non-members can access group APIs
+);
+
+-- NB: awful hack to get the default to be true on postgres and 1 on sqlite
+INSERT INTO groups_new
+    SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups;
+
+DROP TABLE groups;
+ALTER TABLE groups_new RENAME TO groups;
+
+CREATE UNIQUE INDEX groups_idx ON groups(group_id);
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
new file mode 100644
index 0000000000..bbfc7f5d1a
--- /dev/null
+++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will recreate the
+-- local_media_repository_url_idx index.
+--
+-- We do this as a bg update not because it is a particularly onerous
+-- operation, but because we'd like it to be a partial index if possible, and
+-- the background_index_update code will understand whether we are on
+-- postgres or sqlite and behave accordingly.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+    ('local_media_repository_url_idx', '{}');
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
new file mode 100644
index 0000000000..cb0d5a2576
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
@@ -0,0 +1,35 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- change the user_directory table to also cover global local user profiles
+-- rather than just profiles within specific rooms.
+
+CREATE TABLE user_directory2 (
+    user_id TEXT NOT NULL,
+    room_id TEXT,
+    display_name TEXT,
+    avatar_url TEXT
+);
+
+INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
+    SELECT user_id, room_id, display_name, avatar_url from user_directory;
+
+DROP TABLE user_directory;
+ALTER TABLE user_directory2 RENAME TO user_directory;
+
+-- create indexes after doing the inserts because that's more efficient.
+-- it also means we can give it the same name as the old one without renaming.
+CREATE INDEX user_directory_room_idx ON user_directory(room_id);
+CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql
new file mode 100644
index 0000000000..d9505f8da1
--- /dev/null
+++ b/synapse/storage/schema/delta/46/user_dir_typos.sql
@@ -0,0 +1,24 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this is just embarassing :|
+ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms;
+
+-- this is only 300K rows on matrix.org and takes ~3s to generate the index,
+-- so is hopefully not going to block anyone else for that long...
+CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id);
+CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id);
+DROP INDEX users_in_pubic_room_room_idx;
+DROP INDEX users_in_pubic_room_user_idx;
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql
new file mode 100644
index 0000000000..f505fb22b5
--- /dev/null
+++ b/synapse/storage/schema/delta/47/last_access_media.sql
@@ -0,0 +1,16 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
new file mode 100644
index 0000000000..31d7a817eb
--- /dev/null
+++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('event_search_postgres_gin', '{}');
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql
new file mode 100644
index 0000000000..edccf4a96f
--- /dev/null
+++ b/synapse/storage/schema/delta/47/push_actions_staging.sql
@@ -0,0 +1,28 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Temporary staging area for push actions that have been calculated for an
+-- event, but the event hasn't yet been persisted.
+-- When the event is persisted the rows are moved over to the
+-- event_push_actions table.
+CREATE TABLE event_push_actions_staging (
+    event_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    actions TEXT NOT NULL,
+    notif SMALLINT NOT NULL,
+    highlight SMALLINT NOT NULL
+);
+
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id);
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    if isinstance(database_engine, PostgresEngine):
+        # if we already have some state groups, we want to start making new
+        # ones with a higher id.
+        cur.execute("SELECT max(id) FROM state_groups")
+        row = cur.fetchone()
+
+        if row[0] is None:
+            start_val = 1
+        else:
+            start_val = row[0] + 1
+
+        cur.execute(
+            "CREATE SEQUENCE state_group_id_seq START WITH %s",
+            (start_val, ),
+        )
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
new file mode 100644
index 0000000000..9248b0b24a
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+    VALUES ('user_ips_last_seen_index', '{}');
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names..
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+    # remove duplicates from group_users & group_invites tables
+    cur.execute("""
+        DELETE FROM group_users WHERE %s NOT IN (
+           SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+        );
+    """ % (rowid, rowid))
+    cur.execute("""
+        DELETE FROM group_invites WHERE %s NOT IN (
+           SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+        );
+    """ % (rowid, rowid))
+
+    for statement in get_statements(FIX_INDEXES.splitlines()):
+        cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+    pass
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql
new file mode 100644
index 0000000000..ce26eaf0c9
--- /dev/null
+++ b/synapse/storage/schema/delta/48/groups_joinable.sql
@@ -0,0 +1,22 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* 
+ * This isn't a real ENUM because sqlite doesn't support it
+ * and we use a default of NULL for inserted rows and interpret
+ * NULL at the python store level as necessary so that existing
+ * rows are given the correct default policy.
+ */
+ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     file TEXT NOT NULL,
     UNIQUE(version, file)
 );
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+    module_name TEXT NOT NULL,
+    file TEXT NOT NULL,
+    UNIQUE(module_name, file)
+);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 8f2b3c4435..6ba3e59889 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,28 +13,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import namedtuple
+import logging
+import re
+import simplejson as json
+
 from twisted.internet import defer
 
 from .background_updates import BackgroundUpdateStore
 from synapse.api.errors import SynapseError
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-import logging
-import re
-import ujson as json
-
 
 logger = logging.getLogger(__name__)
 
+SearchEntry = namedtuple('SearchEntry', [
+    'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+    'origin_server_ts',
+])
+
 
 class SearchStore(BackgroundUpdateStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
+    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, hs):
-        super(SearchStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(SearchStore, self).__init__(db_conn, hs)
         self.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
@@ -42,23 +49,35 @@ class SearchStore(BackgroundUpdateStore):
             self.EVENT_SEARCH_ORDER_UPDATE_NAME,
             self._background_reindex_search_order
         )
-        self.register_background_update_handler(
+
+        # we used to have a background update to turn the GIN index into a
+        # GIST one; we no longer do that (obviously) because we actually want
+        # a GIN index. However, it's possible that some people might still have
+        # the background update queued, so we register a handler to clear the
+        # background update.
+        self.register_noop_background_update(
             self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
-            self._background_reindex_gist_search
+        )
+
+        self.register_background_update_handler(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
+            self._background_reindex_gin_search
         )
 
     @defer.inlineCallbacks
     def _background_reindex_search(self, progress, batch_size):
+        # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
-        INSERT_CLUMP_SIZE = 1000
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
 
         def reindex_search_txn(txn):
             sql = (
-                "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+                "SELECT stream_ordering, event_id, room_id, type, json, "
+                " origin_server_ts FROM events"
+                " JOIN event_json USING (room_id, event_id)"
                 " WHERE ? <= stream_ordering AND stream_ordering < ?"
                 " AND (%s)"
                 " ORDER BY stream_ordering DESC"
@@ -67,6 +86,10 @@ class SearchStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
+            # we could stream straight from the results into
+            # store_search_entries_txn with a generator function, but that
+            # would mean having two cursors open on the database at once.
+            # Instead we just build a list of results.
             rows = self.cursor_to_dict(txn)
             if not rows:
                 return 0
@@ -79,9 +102,12 @@ class SearchStore(BackgroundUpdateStore):
                     event_id = row["event_id"]
                     room_id = row["room_id"]
                     etype = row["type"]
+                    stream_ordering = row["stream_ordering"]
+                    origin_server_ts = row["origin_server_ts"]
                     try:
-                        content = json.loads(row["content"])
-                    except:
+                        event_json = json.loads(row["json"])
+                        content = event_json["content"]
+                    except Exception:
                         continue
 
                     if etype == "m.room.message":
@@ -93,6 +119,8 @@ class SearchStore(BackgroundUpdateStore):
                     elif etype == "m.room.name":
                         key = "content.name"
                         value = content["name"]
+                    else:
+                        raise Exception("unexpected event type %s" % etype)
                 except (KeyError, AttributeError):
                     # If the event is missing a necessary field then
                     # skip over it.
@@ -103,25 +131,16 @@ class SearchStore(BackgroundUpdateStore):
                     # then skip over it
                     continue
 
-                event_search_rows.append((event_id, room_id, key, value))
+                event_search_rows.append(SearchEntry(
+                    key=key,
+                    value=value,
+                    event_id=event_id,
+                    room_id=room_id,
+                    stream_ordering=stream_ordering,
+                    origin_server_ts=origin_server_ts,
+                ))
 
-            if isinstance(self.database_engine, PostgresEngine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, vector)"
-                    " VALUES (?,?,?,to_tsvector('english', ?))"
-                )
-            elif isinstance(self.database_engine, Sqlite3Engine):
-                sql = (
-                    "INSERT INTO event_search (event_id, room_id, key, value)"
-                    " VALUES (?,?,?,?)"
-                )
-            else:
-                # This should be unreachable.
-                raise Exception("Unrecognized database engine")
-
-            for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
-                clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
-                txn.executemany(sql, clump)
+            self.store_search_entries_txn(txn, event_search_rows)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -145,25 +164,48 @@ class SearchStore(BackgroundUpdateStore):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _background_reindex_gist_search(self, progress, batch_size):
+    def _background_reindex_gin_search(self, progress, batch_size):
+        """This handles old synapses which used GIST indexes, if any;
+        converting them back to be GIN as per the actual schema.
+        """
+
         def create_index(conn):
             conn.rollback()
-            conn.set_session(autocommit=True)
-            c = conn.cursor()
 
-            c.execute(
-                "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist"
-                " ON event_search USING GIST (vector)"
-            )
+            # we have to set autocommit, because postgres refuses to
+            # CREATE INDEX CONCURRENTLY without it.
+            conn.set_session(autocommit=True)
 
-            c.execute("DROP INDEX event_search_fts_idx")
+            try:
+                c = conn.cursor()
 
-            conn.set_session(autocommit=False)
+                # if we skipped the conversion to GIST, we may already/still
+                # have an event_search_fts_idx; unfortunately postgres 9.4
+                # doesn't support CREATE INDEX IF EXISTS so we just catch the
+                # exception and ignore it.
+                import psycopg2
+                try:
+                    c.execute(
+                        "CREATE INDEX CONCURRENTLY event_search_fts_idx"
+                        " ON event_search USING GIN (vector)"
+                    )
+                except psycopg2.ProgrammingError as e:
+                    logger.warn(
+                        "Ignoring error %r when trying to switch from GIST to GIN",
+                        e
+                    )
+
+                # we should now be able to delete the GIST index.
+                c.execute(
+                    "DROP INDEX IF EXISTS event_search_fts_idx_gist"
+                )
+            finally:
+                conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
             yield self.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
         defer.returnValue(1)
 
     @defer.inlineCallbacks
@@ -242,6 +284,85 @@ class SearchStore(BackgroundUpdateStore):
 
         defer.returnValue(num_rows)
 
+    def store_event_search_txn(self, txn, event, key, value):
+        """Add event to the search table
+
+        Args:
+            txn (cursor):
+            event (EventBase):
+            key (str):
+            value (str):
+        """
+        self.store_search_entries_txn(
+            txn,
+            (SearchEntry(
+                key=key,
+                value=value,
+                event_id=event.event_id,
+                room_id=event.room_id,
+                stream_ordering=event.internal_metadata.stream_ordering,
+                origin_server_ts=event.origin_server_ts,
+            ),),
+        )
+
+    def store_search_entries_txn(self, txn, entries):
+        """Add entries to the search table
+
+        Args:
+            txn (cursor):
+            entries (iterable[SearchEntry]):
+                entries to be added to the table
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = (
+                "INSERT INTO event_search"
+                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+            )
+
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+                entry.stream_ordering, entry.origin_server_ts,
+            ) for entry in entries)
+
+            # inserts to a GIN index are normally batched up into a pending
+            # list, and then all committed together once the list gets to a
+            # certain size. The trouble with that is that postgres (pre-9.5)
+            # uses work_mem to determine the length of the list, and work_mem
+            # is typically very large.
+            #
+            # We therefore reduce work_mem while we do the insert.
+            #
+            # (postgres 9.5 uses the separate gin_pending_list_limit setting,
+            # so doesn't suffer the same problem, but changing work_mem will
+            # be harmless)
+            #
+            # Note that we don't need to worry about restoring it on
+            # exception, because exceptions will cause the transaction to be
+            # rolled back, including the effects of the SET command.
+            #
+            # Also: we use SET rather than SET LOCAL because there's lots of
+            # other stuff going on in this transaction, which want to have the
+            # normal work_mem setting.
+
+            txn.execute("SET work_mem='256kB'")
+            txn.executemany(sql, args)
+            txn.execute("RESET work_mem")
+
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = (
+                "INSERT INTO event_search (event_id, room_id, key, value)"
+                " VALUES (?,?,?,?)"
+            )
+            args = ((
+                entry.event_id, entry.room_id, entry.key, entry.value,
+            ) for entry in entries)
+
+            txn.executemany(sql, args)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
     @defer.inlineCallbacks
     def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
@@ -407,7 +528,7 @@ class SearchStore(BackgroundUpdateStore):
                 origin_server_ts, stream = pagination_token.split(",")
                 origin_server_ts = int(origin_server_ts)
                 stream = int(stream)
-            except:
+            except Exception:
                 raise SynapseError(400, "Invalid pagination token")
 
             clauses.append(
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index e1dca927d7..9e6eaaa532 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.util.caches.descriptors import cached, cachedList
 
 
-class SignatureStore(SQLBaseStore):
-    """Persistence for event signatures and hashes"""
-
+class SignatureWorkerStore(SQLBaseStore):
     @cached()
     def get_event_reference_hash(self, event_id):
-        return self._get_event_reference_hashes_txn(event_id)
+        # This is a dummy function to allow get_event_reference_hashes
+        # to use its cache
+        raise NotImplementedError()
 
     @cachedList(cached_method_name="get_event_reference_hash",
                 list_name="event_ids", num_args=1)
@@ -72,7 +72,11 @@ class SignatureStore(SQLBaseStore):
             " WHERE event_id = ?"
         )
         txn.execute(query, (event_id, ))
-        return {k: v for k, v in txn.fetchall()}
+        return {k: v for k, v in txn}
+
+
+class SignatureStore(SignatureWorkerStore):
+    """Persistence for event signatures and hashes"""
 
     def _store_event_reference_hashes_txn(self, txn, events):
         """Store a hash for a PDU
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 84482d8285..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,14 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches import intern_string
-from synapse.storage.engines import PostgresEngine
+from collections import namedtuple
+import logging
 
 from twisted.internet import defer
 
-import logging
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.stringutils import to_ascii
+from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -28,45 +32,97 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
-class StateStore(SQLBaseStore):
-    """ Keeps track of the state at a given event.
+class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the itrable flag when caching
+    """
+    __slots__ = []
 
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
 
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
 
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
+class StateGroupWorkerStore(SQLBaseStore):
+    """The parts of StateGroupStore that can be called from workers.
     """
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, hs):
-        super(StateStore, self).__init__(hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
+    def __init__(self, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
         )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME,
-            self._background_index_state,
+
+    @cached(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            deferred: dict of (type, state_key) -> event_id
+        """
+        def _get_current_state_ids_txn(txn):
+            txn.execute(
+                """SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+                """,
+                (room_id,)
+            )
+
+            return {
+                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+            }
+
+        return self.runInteraction(
+            "get_current_state_ids",
+            _get_current_state_ids_txn,
         )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
+
+    @cached(max_entries=10000, iterable=True)
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            (prev_group, delta_ids), where both may be None.
+        """
+        def _get_state_group_delta_txn(txn):
+            prev_group = self._simple_select_one_onecol_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={
+                    "state_group": state_group,
+                },
+                retcol="prev_state_group",
+                allow_none=True,
+            )
+
+            if not prev_group:
+                return _GetStateGroupDelta(None, None)
+
+            delta_ids = self._simple_select_list_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={
+                    "state_group": state_group,
+                },
+                retcols=("type", "state_key", "event_id",)
+            )
+
+            return _GetStateGroupDelta(prev_group, {
+                (row["type"], row["state_key"]): row["event_id"]
+                for row in delta_ids
+            })
+        return self.runInteraction(
+            "get_state_group_delta",
+            _get_state_group_delta_txn,
         )
 
     @defer.inlineCallbacks
@@ -78,12 +134,26 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups)
 
         defer.returnValue(group_to_state)
 
     @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the state IDs for the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        defer.returnValue(group_to_state[state_group])
+
+    @defer.inlineCallbacks
     def get_state_groups(self, room_id, event_ids):
         """ Get the state groups for the given list of event_ids
 
@@ -96,156 +166,21 @@ class StateStore(SQLBaseStore):
 
         state_event_map = yield self.get_events(
             [
-                ev_id for group_ids in group_to_ids.values()
-                for ev_id in group_ids.values()
+                ev_id for group_ids in group_to_ids.itervalues()
+                for ev_id in group_ids.itervalues()
             ],
             get_prev_content=False
         )
 
         defer.returnValue({
             group: [
-                state_event_map[v] for v in event_id_map.values() if v in state_event_map
+                state_event_map[v] for v in event_id_map.itervalues()
+                if v in state_event_map
             ]
-            for group, event_id_map in group_to_ids.items()
+            for group, event_id_map in group_to_ids.iteritems()
         })
 
-    def _have_persisted_state_group_txn(self, txn, state_group):
-        txn.execute(
-            "SELECT count(*) FROM state_groups WHERE id = ?",
-            (state_group,)
-        )
-        row = txn.fetchone()
-        return row and row[0]
-
-    def _store_mult_state_groups_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            if context.current_state_ids is None:
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-            if self._have_persisted_state_group_txn(txn, context.state_group):
-                continue
-
-            self._simple_insert_txn(
-                txn,
-                table="state_groups",
-                values={
-                    "id": context.state_group,
-                    "room_id": event.room_id,
-                    "event_id": event.event_id,
-                },
-            )
-
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if context.prev_group:
-                potential_hops = self._count_state_group_hops_txn(
-                    txn, context.prev_group
-                )
-            if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={
-                        "state_group": context.state_group,
-                        "prev_state_group": context.prev_group,
-                    },
-                )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.delta_ids.items()
-                    ],
-                )
-            else:
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": context.state_group,
-                            "room_id": event.room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in context.current_state_ids.items()
-                    ],
-                )
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {
-                    "state_group": state_group_id,
-                    "event_id": event_id,
-                }
-                for event_id, state_group_id in state_groups.items()
-            ],
-        )
-
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = ("""
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """)
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
-    @cached(num_args=2, max_entries=100000, iterable=True)
-    def _get_state_group_from_group(self, group, types):
-        raise NotImplementedError()
-
-    @cachedList(cached_method_name="_get_state_group_from_group",
-                list_name="groups", num_args=2, inlineCallbacks=True)
+    @defer.inlineCallbacks
     def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
@@ -305,6 +240,9 @@ class StateStore(SQLBaseStore):
                     (
                         "AND type = ? AND state_key = ?",
                         (etype, state_key)
+                    ) if state_key is not None else (
+                        "AND type = ?",
+                        (etype,)
                     )
                     for etype, state_key in types
                 ]
@@ -319,15 +257,24 @@ class StateStore(SQLBaseStore):
                     args.extend(where_args)
 
                     txn.execute(sql % (where_clause,), args)
-                    rows = self.cursor_to_dict(txn)
-                    for row in rows:
-                        key = (row["type"], row["state_key"])
-                        results[group][key] = row["event_id"]
+                    for row in txn:
+                        typ, state_key, event_id = row
+                        key = (typ, state_key)
+                        results[group][key] = event_id
         else:
+            where_args = []
+            where_clauses = []
+            wildcard_types = False
             if types is not None:
-                where_clause = "AND (%s)" % (
-                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                )
+                for typ in types:
+                    if typ[1] is None:
+                        where_clauses.append("(type = ?)")
+                        where_args.extend(typ[0])
+                        wildcard_types = True
+                    else:
+                        where_clauses.append("(type = ? AND state_key = ?)")
+                        where_args.extend([typ[0], typ[1]])
+                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
             else:
                 where_clause = ""
 
@@ -344,23 +291,30 @@ class StateStore(SQLBaseStore):
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
                     if types:
-                        args.extend(i for typ in types for i in typ)
+                        args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
                         " WHERE state_group = ? %s" % (where_clause,),
                         args
                     )
-                    rows = txn.fetchall()
-                    results[group].update({
-                        (typ, state_key): event_id
-                        for typ, state_key, event_id in rows
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
                         if (typ, state_key) not in results[group]
-                    })
+                    )
 
-                    # If the lengths match then we must have all the types,
-                    # so no need to go walk further down the tree.
-                    if types is not None and len(results[group]) == len(types):
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        types is not None and
+                        not wildcard_types and
+                        len(results[group]) == len(types)
+                    ):
                         break
 
                     next_group = self._simple_select_one_onecol_txn(
@@ -393,21 +347,21 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         state_event_map = yield self.get_events(
-            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+            [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
             get_prev_content=False
         )
 
         event_to_state = {
             event_id: {
                 k: state_event_map[v]
-                for k, v in group_to_state[group].items()
+                for k, v in group_to_state[group].iteritems()
                 if v in state_event_map
             }
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -430,12 +384,12 @@ class StateStore(SQLBaseStore):
             event_ids,
         )
 
-        groups = set(event_to_groups.values())
+        groups = set(event_to_groups.itervalues())
         group_to_state = yield self._get_state_for_groups(groups, types)
 
         event_to_state = {
             event_id: group_to_state[group]
-            for event_id, group in event_to_groups.items()
+            for event_id, group in event_to_groups.iteritems()
         }
 
         defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -474,8 +428,8 @@ class StateStore(SQLBaseStore):
         state_map = yield self.get_state_ids_for_events([event_id], types)
         defer.returnValue(state_map[event_id])
 
-    @cached(num_args=2, max_entries=10000)
-    def _get_state_group_for_event(self, room_id, event_id):
+    @cached(max_entries=50000)
+    def _get_state_group_for_event(self, event_id):
         return self._simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={
@@ -517,20 +471,22 @@ class StateStore(SQLBaseStore):
                 where a `state_key` of `None` matches all state_keys for the
                 `type`.
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
+
         for typ, state_key in types:
+            key = (typ, state_key)
             if state_key is None:
                 type_to_key[typ] = None
-                missing_types.add((typ, state_key))
+                missing_types.add(key)
             else:
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
-                if (typ, state_key) not in state_dict_ids:
-                    missing_types.add((typ, state_key))
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types.add(key)
 
         sentinel = object()
 
@@ -544,10 +500,10 @@ class StateStore(SQLBaseStore):
                 return True
             return False
 
-        got_all = not (missing_types or types is None)
+        got_all = is_all or not missing_types
 
         return {
-            k: v for k, v in state_dict_ids.items()
+            k: v for k, v in state_dict_ids.iteritems()
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -561,7 +517,7 @@ class StateStore(SQLBaseStore):
         Args:
             group: The state group to lookup
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = self._state_group_cache.get(group)
 
         return state_dict_ids, is_all
 
@@ -578,7 +534,7 @@ class StateStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+                state_dict_ids, _, got_all = self._get_some_state_from_cache(
                     group, types
                 )
                 results[group] = state_dict_ids
@@ -606,46 +562,247 @@ class StateStore(SQLBaseStore):
 
             # Now we want to update the cache with all the things we fetched
             # from the database.
-            for group, group_state_dict in group_to_state_dict.items():
-                if types:
-                    # We delibrately put key -> None mappings into the cache to
-                    # cache absence of the key, on the assumption that if we've
-                    # explicitly asked for some types then we will probably ask
-                    # for them again.
-                    state_dict = {
-                        (intern_string(etype), intern_string(state_key)): None
-                        for (etype, state_key) in types
-                    }
-                    state_dict.update(results[group])
-                    results[group] = state_dict
-                else:
-                    state_dict = results[group]
-
-                state_dict.update({
-                    (intern_string(k[0]), intern_string(k[1])): v
-                    for k, v in group_state_dict.items()
-                })
+            for group, group_state_dict in group_to_state_dict.iteritems():
+                state_dict = results[group]
+
+                state_dict.update(
+                    ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
+                    for k, v in group_state_dict.iteritems()
+                )
 
                 self._state_group_cache.update(
                     cache_seq_num,
                     key=group,
                     value=state_dict,
                     full=(types is None),
+                    known_absent=types,
                 )
 
-        # Remove all the entries with None values. The None values were just
-        # used for bookkeeping in the cache.
-        for group, state_dict in results.items():
-            results[group] = {
-                key: event_id
-                for key, event_id in state_dict.items()
-                if event_id
-            }
-
         defer.returnValue(results)
 
-    def get_next_state_group(self):
-        return self._state_groups_id_gen.get_next()
+    def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+                          current_state_ids):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self.database_engine.get_next_state_group_id(txn)
+
+            self._simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={
+                    "id": state_group,
+                    "room_id": room_id,
+                    "event_id": event_id,
+                },
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't tooo long, as otherwise read performance degrades.
+            if prev_group:
+                is_in_db = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (prev_group,)
+                    )
+
+                potential_hops = self._count_state_group_hops_txn(
+                    txn, prev_group
+                )
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={
+                        "state_group": state_group,
+                        "prev_state_group": prev_group,
+                    },
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in delta_ids.iteritems()
+                    ],
+                )
+            else:
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in current_state_ids.iteritems()
+                    ],
+                )
+
+            # Prefill the state group cache with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=state_group,
+                value=dict(current_state_ids),
+                full=True,
+            )
+
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
+
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = ("""
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """)
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is a assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent it inherits the state group from its parent.)
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event with in the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME,
+            self._background_index_state,
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
+        self._simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {
+                    "state_group": state_group_id,
+                    "event_id": event_id,
+                }
+                for event_id, state_group_id in state_groups.iteritems()
+            ],
+        )
+
+        for event_id, state_group_id in state_groups.iteritems():
+            txn.call_after(
+                self._get_state_group_for_event.prefill,
+                (event_id,), state_group_id
+            )
 
     @defer.inlineCallbacks
     def _background_deduplicate_state(self, progress, batch_size):
@@ -727,7 +884,7 @@ class StateStore(SQLBaseStore):
                         # of keys
 
                         delta_state = {
-                            key: value for key, value in curr_state.items()
+                            key: value for key, value in curr_state.iteritems()
                             if prev_state.get(key, None) != value
                         }
 
@@ -767,7 +924,7 @@ class StateStore(SQLBaseStore):
                                     "state_key": key[1],
                                     "event_id": state_id,
                                 }
-                                for key, state_id in delta_state.items()
+                                for key, state_id in delta_state.iteritems()
                             ],
                         )
 
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 200d124632..f0784ba137 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -35,15 +35,20 @@ what sort order was used:
 
 from twisted.internet import defer
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
+
 from synapse.util.caches.descriptors import cached
-from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+import abc
 import logging
 
+from six.moves import range
+
 
 logger = logging.getLogger(__name__)
 
@@ -143,81 +148,41 @@ def filter_to_clause(event_filter):
     return " AND ".join(clauses), args
 
 
-class StreamStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
-        # NB this lives here instead of appservice.py so we can reuse the
-        # 'private' StreamToken class in this file.
-        if limit:
-            limit = max(limit, MAX_STREAM_SIZE)
-        else:
-            limit = MAX_STREAM_SIZE
-
-        # From and to keys should be integers from ordering.
-        from_id = RoomStreamToken.parse_stream_token(from_key)
-        to_id = RoomStreamToken.parse_stream_token(to_key)
-
-        if from_key == to_key:
-            defer.returnValue(([], to_key))
-            return
-
-        # select all the events between from/to with a sensible limit
-        sql = (
-            "SELECT e.event_id, e.room_id, e.type, s.state_key, "
-            "e.stream_ordering FROM events AS e "
-            "LEFT JOIN state_events as s ON "
-            "e.event_id = s.event_id "
-            "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
-            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
-        ) % {
-            "limit": limit
-        }
-
-        def f(txn):
-            # pull out all the events between the tokens
-            txn.execute(sql, (from_id.stream, to_id.stream,))
-            rows = self.cursor_to_dict(txn)
-
-            # Logic:
-            #  - We want ALL events which match the AS room_id regex
-            #  - We want ALL events which match the rooms represented by the AS
-            #    room_alias regex
-            #  - We want ALL events for rooms that AS users have joined.
-            # This is currently supported via get_app_service_rooms (which is
-            # used for the Notifier listener rooms). We can't reasonably make a
-            # SQL query for these room IDs, so we'll pull all the events between
-            # from/to and filter in python.
-            rooms_for_as = self._get_app_service_rooms_txn(txn, service)
-            room_ids_for_as = [r.room_id for r in rooms_for_as]
-
-            def app_service_interested(row):
-                if row["room_id"] in room_ids_for_as:
-                    return True
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
+    which can be called in the initializer.
+    """
 
-                if row["type"] == EventTypes.Member:
-                    if service.is_interested_in_user(row.get("state_key")):
-                        return True
-                return False
+    __metaclass__ = abc.ABCMeta
 
-            return [r for r in rows if app_service_interested(r)]
+    def __init__(self, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(db_conn, hs)
 
-        rows = yield self.runInteraction("get_appservice_room_stream", f)
-
-        ret = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
+        events_max = self.get_room_max_stream_ordering()
+        event_cache_prefill, min_event_val = self._get_cache_dict(
+            db_conn, "events",
+            entity_column="room_id",
+            stream_column="stream_ordering",
+            max_value=events_max,
+        )
+        self._events_stream_cache = StreamChangeCache(
+            "EventsRoomStreamChangeCache", min_event_val,
+            prefilled_cache=event_cache_prefill,
+        )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max,
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._stream_order_on_start = self.get_room_max_stream_ordering()
 
-        if rows:
-            key = "s%d" % max(r["stream_ordering"] for r in rows)
-        else:
-            # Assume we didn't get anything because there was nothing to
-            # get.
-            key = to_key
+    @abc.abstractmethod
+    def get_room_max_stream_ordering(self):
+        raise NotImplementedError()
 
-        defer.returnValue((ret, key))
+    @abc.abstractmethod
+    def get_room_min_stream_ordering(self):
+        raise NotImplementedError()
 
     @defer.inlineCallbacks
     def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
@@ -233,13 +198,14 @@ class StreamStore(SQLBaseStore):
 
         results = {}
         room_ids = list(room_ids)
-        for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
-            res = yield preserve_context_over_deferred(defer.gatherResults([
-                preserve_fn(self.get_room_events_stream_for_room)(
+        for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
+            res = yield make_deferred_yieldable(defer.gatherResults([
+                run_in_background(
+                    self.get_room_events_stream_for_room,
                     room_id, from_key, to_key, limit, order=order,
                 )
                 for room_id in rm_ids
-            ]))
+            ], consumeErrors=True))
             results.update(dict(zip(rm_ids, res)))
 
         defer.returnValue(results)
@@ -381,88 +347,6 @@ class StreamStore(SQLBaseStore):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def paginate_room_events(self, room_id, from_key, to_key=None,
-                             direction='b', limit=-1, event_filter=None):
-        # Tokens really represent positions between elements, but we use
-        # the convention of pointing to the event before the gap. Hence
-        # we have a bit of asymmetry when it comes to equalities.
-        args = [False, room_id]
-        if direction == 'b':
-            order = "DESC"
-            bounds = upper_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, lower_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-        else:
-            order = "ASC"
-            bounds = lower_bound(
-                RoomStreamToken.parse(from_key), self.database_engine
-            )
-            if to_key:
-                bounds = "%s AND %s" % (bounds, upper_bound(
-                    RoomStreamToken.parse(to_key), self.database_engine
-                ))
-
-        filter_clause, filter_args = filter_to_clause(event_filter)
-
-        if filter_clause:
-            bounds += " AND " + filter_clause
-            args.extend(filter_args)
-
-        if int(limit) > 0:
-            args.append(int(limit))
-            limit_str = " LIMIT ?"
-        else:
-            limit_str = ""
-
-        sql = (
-            "SELECT * FROM events"
-            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
-            " ORDER BY topological_ordering %(order)s,"
-            " stream_ordering %(order)s %(limit)s"
-        ) % {
-            "bounds": bounds,
-            "order": order,
-            "limit": limit_str
-        }
-
-        def f(txn):
-            txn.execute(sql, args)
-
-            rows = self.cursor_to_dict(txn)
-
-            if rows:
-                topo = rows[-1]["topological_ordering"]
-                toke = rows[-1]["stream_ordering"]
-                if direction == 'b':
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    toke -= 1
-                next_token = str(RoomStreamToken(topo, toke))
-            else:
-                # TODO (erikj): We should work out what to do here instead.
-                next_token = to_key if to_key else from_key
-
-            return rows, next_token,
-
-        rows, token = yield self.runInteraction("paginate_room_events", f)
-
-        events = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
-        )
-
-        self._set_before_and_after(events, rows)
-
-        defer.returnValue((events, token))
-
-    @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
         rows, token = yield self.get_recent_event_ids_for_room(
             room_id, limit, end_token, from_token
@@ -534,6 +418,33 @@ class StreamStore(SQLBaseStore):
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
+        """Gets details of the first event in a room at or after a stream ordering
+
+        Args:
+            room_id (str):
+            stream_ordering (int):
+
+        Returns:
+            Deferred[(int, int, str)]:
+                (stream ordering, topological ordering, event_id)
+        """
+        def _f(txn):
+            sql = (
+                "SELECT stream_ordering, topological_ordering, event_id"
+                " FROM events"
+                " WHERE room_id = ? AND stream_ordering >= ?"
+                " AND NOT outlier"
+                " ORDER BY stream_ordering"
+                " LIMIT 1"
+            )
+            txn.execute(sql, (room_id, stream_ordering, ))
+            return txn.fetchone()
+
+        return self.runInteraction(
+            "get_room_event_after_stream_ordering", _f,
+        )
+
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
         """Returns the current token for rooms stream.
@@ -542,7 +453,7 @@ class StreamStore(SQLBaseStore):
         `room_id` causes it to return the current room specific topological
         token.
         """
-        token = yield self._stream_id_gen.get_current_token()
+        token = yield self.get_room_max_stream_ordering()
         if room_id is None:
             defer.returnValue("s%d" % (token,))
         else:
@@ -552,12 +463,6 @@ class StreamStore(SQLBaseStore):
             )
             defer.returnValue("t%d-%d" % (topo, token))
 
-    def get_room_max_stream_ordering(self):
-        return self._stream_id_gen.get_current_token()
-
-    def get_room_min_stream_ordering(self):
-        return self._backfill_id_gen.get_current_token()
-
     def get_stream_token_for_event(self, event_id):
         """The stream token for an event
         Args:
@@ -829,3 +734,96 @@ class StreamStore(SQLBaseStore):
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
+
+    def has_room_changed_since(self, room_id, stream_id):
+        return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+
+
+class StreamStore(StreamWorkerStore):
+    def get_room_max_stream_ordering(self):
+        return self._stream_id_gen.get_current_token()
+
+    def get_room_min_stream_ordering(self):
+        return self._backfill_id_gen.get_current_token()
+
+    @defer.inlineCallbacks
+    def paginate_room_events(self, room_id, from_key, to_key=None,
+                             direction='b', limit=-1, event_filter=None):
+        # Tokens really represent positions between elements, but we use
+        # the convention of pointing to the event before the gap. Hence
+        # we have a bit of asymmetry when it comes to equalities.
+        args = [False, room_id]
+        if direction == 'b':
+            order = "DESC"
+            bounds = upper_bound(
+                RoomStreamToken.parse(from_key), self.database_engine
+            )
+            if to_key:
+                bounds = "%s AND %s" % (bounds, lower_bound(
+                    RoomStreamToken.parse(to_key), self.database_engine
+                ))
+        else:
+            order = "ASC"
+            bounds = lower_bound(
+                RoomStreamToken.parse(from_key), self.database_engine
+            )
+            if to_key:
+                bounds = "%s AND %s" % (bounds, upper_bound(
+                    RoomStreamToken.parse(to_key), self.database_engine
+                ))
+
+        filter_clause, filter_args = filter_to_clause(event_filter)
+
+        if filter_clause:
+            bounds += " AND " + filter_clause
+            args.extend(filter_args)
+
+        if int(limit) > 0:
+            args.append(int(limit))
+            limit_str = " LIMIT ?"
+        else:
+            limit_str = ""
+
+        sql = (
+            "SELECT * FROM events"
+            " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
+            " ORDER BY topological_ordering %(order)s,"
+            " stream_ordering %(order)s %(limit)s"
+        ) % {
+            "bounds": bounds,
+            "order": order,
+            "limit": limit_str
+        }
+
+        def f(txn):
+            txn.execute(sql, args)
+
+            rows = self.cursor_to_dict(txn)
+
+            if rows:
+                topo = rows[-1]["topological_ordering"]
+                toke = rows[-1]["stream_ordering"]
+                if direction == 'b':
+                    # Tokens are positions between events.
+                    # This token points *after* the last event in the chunk.
+                    # We need it to point to the event before it in the chunk
+                    # when we are going backwards so we subtract one from the
+                    # stream part.
+                    toke -= 1
+                next_token = str(RoomStreamToken(topo, toke))
+            else:
+                # TODO (erikj): We should work out what to do here instead.
+                next_token = to_key if to_key else from_key
+
+            return rows, next_token,
+
+        rows, token = yield self.runInteraction("paginate_room_events", f)
+
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 5a2c1aa59b..6671d3cfca 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,25 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from synapse.storage.account_data import AccountDataWorkerStore
+
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
 
-import ujson as json
+import simplejson as json
 import logging
 
-logger = logging.getLogger(__name__)
+from six.moves import range
 
+logger = logging.getLogger(__name__)
 
-class TagsStore(SQLBaseStore):
-    def get_max_account_data_stream_id(self):
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            A deferred int.
-        """
-        return self._account_data_id_gen.get_current_token()
 
+class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
     def get_tags_for_user(self, user_id):
         """Get all the tags for a user.
@@ -95,7 +91,7 @@ class TagsStore(SQLBaseStore):
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
                 tags = []
-                for tag, content in txn.fetchall():
+                for tag, content in txn:
                     tags.append(json.dumps(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, user_id, room_id, tag_json))
@@ -104,7 +100,7 @@ class TagsStore(SQLBaseStore):
 
         batch_size = 50
         results = []
-        for i in xrange(0, len(tag_ids), batch_size):
+        for i in range(0, len(tag_ids), batch_size):
             tags = yield self.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
@@ -132,7 +128,7 @@ class TagsStore(SQLBaseStore):
                 " WHERE user_id = ? AND stream_id > ?"
             )
             txn.execute(sql, (user_id, stream_id))
-            room_ids = [row[0] for row in txn.fetchall()]
+            room_ids = [row[0] for row in txn]
             return room_ids
 
         changed = self._account_data_stream_cache.has_entity_changed(
@@ -170,6 +166,8 @@ class TagsStore(SQLBaseStore):
             row["tag"]: json.loads(row["content"]) for row in rows
         })
 
+
+class TagsStore(TagsWorkerStore):
     @defer.inlineCallbacks
     def add_tag_to_room(self, user_id, room_id, tag, content):
         """Add a tag to a room for a user.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 809fdd311f..f825264ea9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json
 from collections import namedtuple
 
 import logging
-import ujson as json
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, hs):
-        super(TransactionStore, self).__init__(hs)
+    def __init__(self, db_conn, hs):
+        super(TransactionStore, self).__init__(db_conn, hs)
 
         self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
 
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
new file mode 100644
index 0000000000..d6e289ffbe
--- /dev/null
+++ b/synapse/storage/user_directory.py
@@ -0,0 +1,764 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import get_domain_from_id, get_localpart_from_id
+
+import re
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectoryStore(SQLBaseStore):
+    @cachedInlineCallbacks(cache_context=True)
+    def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context):
+        """Check if the room is either world_readable or publically joinable
+        """
+        current_state_ids = yield self.get_current_state_ids(
+            room_id, on_invalidate=cache_context.invalidate
+        )
+
+        join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+        if join_rules_id:
+            join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+            if join_rule_ev:
+                if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
+                    defer.returnValue(True)
+
+        hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
+        if hist_vis_id:
+            hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+            if hist_vis_ev:
+                if hist_vis_ev.content.get("history_visibility") == "world_readable":
+                    defer.returnValue(True)
+
+        defer.returnValue(False)
+
+    @defer.inlineCallbacks
+    def add_users_to_public_room(self, room_id, user_ids):
+        """Add user to the list of users in public rooms
+
+        Args:
+            room_id (str): A room_id that all users are in that is world_readable
+                or publically joinable
+            user_ids (list(str)): Users to add
+        """
+        yield self._simple_insert_many(
+            table="users_in_public_rooms",
+            values=[
+                {
+                    "user_id": user_id,
+                    "room_id": room_id,
+                }
+                for user_id in user_ids
+            ],
+            desc="add_users_to_public_room"
+        )
+        for user_id in user_ids:
+            self.get_user_in_public_room.invalidate((user_id,))
+
+    def add_profiles_to_user_dir(self, room_id, users_with_profile):
+        """Add profiles to the user directory
+
+        Args:
+            room_id (str): A room_id that all users are joined to
+            users_with_profile (dict): Users to add to directory in the form of
+                mapping of user_id -> ProfileInfo
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            # We weight the loclpart most highly, then display name and finally
+            # server name
+            sql = """
+                INSERT INTO user_directory_search(user_id, vector)
+                VALUES (?,
+                    setweight(to_tsvector('english', ?), 'A')
+                    || setweight(to_tsvector('english', ?), 'D')
+                    || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                )
+            """
+            args = (
+                (
+                    user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                    profile.display_name,
+                )
+                for user_id, profile in users_with_profile.iteritems()
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            sql = """
+                INSERT INTO user_directory_search(user_id, value)
+                VALUES (?,?)
+            """
+            args = (
+                (
+                    user_id,
+                    "%s %s" % (user_id, p.display_name,) if p.display_name else user_id
+                )
+                for user_id, p in users_with_profile.iteritems()
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        def _add_profiles_to_user_dir_txn(txn):
+            txn.executemany(sql, args)
+            self._simple_insert_many_txn(
+                txn,
+                table="user_directory",
+                values=[
+                    {
+                        "user_id": user_id,
+                        "room_id": room_id,
+                        "display_name": profile.display_name,
+                        "avatar_url": profile.avatar_url,
+                    }
+                    for user_id, profile in users_with_profile.iteritems()
+                ]
+            )
+            for user_id in users_with_profile:
+                txn.call_after(
+                    self.get_user_in_directory.invalidate, (user_id,)
+                )
+
+        return self.runInteraction(
+            "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn
+        )
+
+    @defer.inlineCallbacks
+    def update_user_in_user_dir(self, user_id, room_id):
+        yield self._simple_update_one(
+            table="user_directory",
+            keyvalues={"user_id": user_id},
+            updatevalues={"room_id": room_id},
+            desc="update_user_in_user_dir",
+        )
+        self.get_user_in_directory.invalidate((user_id,))
+
+    def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
+        def _update_profile_in_user_dir_txn(txn):
+            new_entry = self._simple_upsert_txn(
+                txn,
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+                insertion_values={"room_id": room_id},
+                values={"display_name": display_name, "avatar_url": avatar_url},
+                lock=False,  # We're only inserter
+            )
+
+            if isinstance(self.database_engine, PostgresEngine):
+                # We weight the localpart most highly, then display name and finally
+                # server name
+                if new_entry:
+                    sql = """
+                        INSERT INTO user_directory_search(user_id, vector)
+                        VALUES (?,
+                            setweight(to_tsvector('english', ?), 'A')
+                            || setweight(to_tsvector('english', ?), 'D')
+                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                        )
+                    """
+                    txn.execute(
+                        sql,
+                        (
+                            user_id, get_localpart_from_id(user_id),
+                            get_domain_from_id(user_id), display_name,
+                        )
+                    )
+                else:
+                    sql = """
+                        UPDATE user_directory_search
+                        SET vector = setweight(to_tsvector('english', ?), 'A')
+                            || setweight(to_tsvector('english', ?), 'D')
+                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                        WHERE user_id = ?
+                    """
+                    txn.execute(
+                        sql,
+                        (
+                            get_localpart_from_id(user_id), get_domain_from_id(user_id),
+                            display_name, user_id,
+                        )
+                    )
+            elif isinstance(self.database_engine, Sqlite3Engine):
+                value = "%s %s" % (user_id, display_name,) if display_name else user_id
+                self._simple_upsert_txn(
+                    txn,
+                    table="user_directory_search",
+                    keyvalues={"user_id": user_id},
+                    values={"value": value},
+                    lock=False,  # We're only inserter
+                )
+            else:
+                # This should be unreachable.
+                raise Exception("Unrecognized database engine")
+
+            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+        return self.runInteraction(
+            "update_profile_in_user_dir", _update_profile_in_user_dir_txn
+        )
+
+    @defer.inlineCallbacks
+    def update_user_in_public_user_list(self, user_id, room_id):
+        yield self._simple_update_one(
+            table="users_in_public_rooms",
+            keyvalues={"user_id": user_id},
+            updatevalues={"room_id": room_id},
+            desc="update_user_in_public_user_list",
+        )
+        self.get_user_in_public_room.invalidate((user_id,))
+
+    def remove_from_user_dir(self, user_id):
+        def _remove_from_user_dir_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="user_directory_search",
+                keyvalues={"user_id": user_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_in_public_rooms",
+                keyvalues={"user_id": user_id},
+            )
+            txn.call_after(
+                self.get_user_in_directory.invalidate, (user_id,)
+            )
+            txn.call_after(
+                self.get_user_in_public_room.invalidate, (user_id,)
+            )
+        return self.runInteraction(
+            "remove_from_user_dir", _remove_from_user_dir_txn,
+        )
+
+    @defer.inlineCallbacks
+    def remove_from_user_in_public_room(self, user_id):
+        yield self._simple_delete(
+            table="users_in_public_rooms",
+            keyvalues={"user_id": user_id},
+            desc="remove_from_user_in_public_room",
+        )
+        self.get_user_in_public_room.invalidate((user_id,))
+
+    def get_users_in_public_due_to_room(self, room_id):
+        """Get all user_ids that are in the room directory becuase they're
+        in the given room_id
+        """
+        return self._simple_select_onecol(
+            table="users_in_public_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            desc="get_users_in_public_due_to_room",
+        )
+
+    @defer.inlineCallbacks
+    def get_users_in_dir_due_to_room(self, room_id):
+        """Get all user_ids that are in the room directory becuase they're
+        in the given room_id
+        """
+        user_ids_dir = yield self._simple_select_onecol(
+            table="user_directory",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            desc="get_users_in_dir_due_to_room",
+        )
+
+        user_ids_pub = yield self._simple_select_onecol(
+            table="users_in_public_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            desc="get_users_in_dir_due_to_room",
+        )
+
+        user_ids_share = yield self._simple_select_onecol(
+            table="users_who_share_rooms",
+            keyvalues={"room_id": room_id},
+            retcol="user_id",
+            desc="get_users_in_dir_due_to_room",
+        )
+
+        user_ids = set(user_ids_dir)
+        user_ids.update(user_ids_pub)
+        user_ids.update(user_ids_share)
+
+        defer.returnValue(user_ids)
+
+    @defer.inlineCallbacks
+    def get_all_rooms(self):
+        """Get all room_ids we've ever known about, in ascending order of "size"
+        """
+        sql = """
+            SELECT room_id FROM current_state_events
+            GROUP BY room_id
+            ORDER BY count(*) ASC
+        """
+        rows = yield self._execute("get_all_rooms", None, sql)
+        defer.returnValue([room_id for room_id, in rows])
+
+    @defer.inlineCallbacks
+    def get_all_local_users(self):
+        """Get all local users
+        """
+        sql = """
+            SELECT name FROM users
+        """
+        rows = yield self._execute("get_all_local_users", None, sql)
+        defer.returnValue([name for name, in rows])
+
+    def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
+        """Insert entries into the users_who_share_rooms table. The first
+        user should be a local user.
+
+        Args:
+            room_id (str)
+            share_private (bool): Is the room private
+            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+        """
+        def _add_users_who_share_room_txn(txn):
+            self._simple_insert_many_txn(
+                txn,
+                table="users_who_share_rooms",
+                values=[
+                    {
+                        "user_id": user_id,
+                        "other_user_id": other_user_id,
+                        "room_id": room_id,
+                        "share_private": share_private,
+                    }
+                    for user_id, other_user_id in user_id_tuples
+                ],
+            )
+            for user_id, other_user_id in user_id_tuples:
+                txn.call_after(
+                    self.get_users_who_share_room_from_dir.invalidate,
+                    (user_id,),
+                )
+                txn.call_after(
+                    self.get_if_users_share_a_room.invalidate,
+                    (user_id, other_user_id),
+                )
+        return self.runInteraction(
+            "add_users_who_share_room", _add_users_who_share_room_txn
+        )
+
+    def update_users_who_share_room(self, room_id, share_private, user_id_sets):
+        """Updates entries in the users_who_share_rooms table. The first
+        user should be a local user.
+
+        Args:
+            room_id (str)
+            share_private (bool): Is the room private
+            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+        """
+        def _update_users_who_share_room_txn(txn):
+            sql = """
+                UPDATE users_who_share_rooms
+                SET room_id = ?, share_private = ?
+                WHERE user_id = ? AND other_user_id = ?
+            """
+            txn.executemany(
+                sql,
+                (
+                    (room_id, share_private, uid, oid)
+                    for uid, oid in user_id_sets
+                )
+            )
+            for user_id, other_user_id in user_id_sets:
+                txn.call_after(
+                    self.get_users_who_share_room_from_dir.invalidate,
+                    (user_id,),
+                )
+                txn.call_after(
+                    self.get_if_users_share_a_room.invalidate,
+                    (user_id, other_user_id),
+                )
+        return self.runInteraction(
+            "update_users_who_share_room", _update_users_who_share_room_txn
+        )
+
+    def remove_user_who_share_room(self, user_id, other_user_id):
+        """Deletes entries in the users_who_share_rooms table. The first
+        user should be a local user.
+
+        Args:
+            room_id (str)
+            share_private (bool): Is the room private
+            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+        """
+        def _remove_user_who_share_room_txn(txn):
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_rooms",
+                keyvalues={
+                    "user_id": user_id,
+                    "other_user_id": other_user_id,
+                },
+            )
+            txn.call_after(
+                self.get_users_who_share_room_from_dir.invalidate,
+                (user_id,),
+            )
+            txn.call_after(
+                self.get_if_users_share_a_room.invalidate,
+                (user_id, other_user_id),
+            )
+
+        return self.runInteraction(
+            "remove_user_who_share_room", _remove_user_who_share_room_txn
+        )
+
+    @cached(max_entries=500000)
+    def get_if_users_share_a_room(self, user_id, other_user_id):
+        """Gets if users share a room.
+
+        Args:
+            user_id (str): Must be a local user_id
+            other_user_id (str)
+
+        Returns:
+            bool|None: None if they don't share a room, otherwise whether they
+            share a private room or not.
+        """
+        return self._simple_select_one_onecol(
+            table="users_who_share_rooms",
+            keyvalues={
+                "user_id": user_id,
+                "other_user_id": other_user_id,
+            },
+            retcol="share_private",
+            allow_none=True,
+            desc="get_if_users_share_a_room",
+        )
+
+    @cachedInlineCallbacks(max_entries=500000, iterable=True)
+    def get_users_who_share_room_from_dir(self, user_id):
+        """Returns the set of users who share a room with `user_id`
+
+        Args:
+            user_id(str): Must be a local user
+
+        Returns:
+            dict: user_id -> share_private mapping
+        """
+        rows = yield self._simple_select_list(
+            table="users_who_share_rooms",
+            keyvalues={
+                "user_id": user_id,
+            },
+            retcols=("other_user_id", "share_private",),
+            desc="get_users_who_share_room_with_user",
+        )
+
+        defer.returnValue({
+            row["other_user_id"]: row["share_private"]
+            for row in rows
+        })
+
+    def get_users_in_share_dir_with_room_id(self, user_id, room_id):
+        """Get all user tuples that are in the users_who_share_rooms due to the
+        given room_id.
+
+        Returns:
+            [(user_id, other_user_id)]: where one of the two will match the given
+            user_id.
+        """
+        sql = """
+            SELECT user_id, other_user_id FROM users_who_share_rooms
+            WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
+        """
+        return self._execute(
+            "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+        )
+
+    @defer.inlineCallbacks
+    def get_rooms_in_common_for_users(self, user_id, other_user_id):
+        """Given two user_ids find out the list of rooms they share.
+        """
+        sql = """
+            SELECT room_id FROM (
+                SELECT c.room_id FROM current_state_events AS c
+                INNER JOIN room_memberships USING (event_id)
+                WHERE type = 'm.room.member'
+                    AND membership = 'join'
+                    AND state_key = ?
+            ) AS f1 INNER JOIN (
+                SELECT c.room_id FROM current_state_events AS c
+                INNER JOIN room_memberships USING (event_id)
+                WHERE type = 'm.room.member'
+                    AND membership = 'join'
+                    AND state_key = ?
+            ) f2 USING (room_id)
+        """
+
+        rows = yield self._execute(
+            "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+        )
+
+        defer.returnValue([room_id for room_id, in rows])
+
+    def delete_all_from_user_dir(self):
+        """Delete the entire user directory
+        """
+        def _delete_all_from_user_dir_txn(txn):
+            txn.execute("DELETE FROM user_directory")
+            txn.execute("DELETE FROM user_directory_search")
+            txn.execute("DELETE FROM users_in_public_rooms")
+            txn.execute("DELETE FROM users_who_share_rooms")
+            txn.call_after(self.get_user_in_directory.invalidate_all)
+            txn.call_after(self.get_user_in_public_room.invalidate_all)
+            txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
+            txn.call_after(self.get_if_users_share_a_room.invalidate_all)
+        return self.runInteraction(
+            "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+        )
+
+    @cached()
+    def get_user_in_directory(self, user_id):
+        return self._simple_select_one(
+            table="user_directory",
+            keyvalues={"user_id": user_id},
+            retcols=("room_id", "display_name", "avatar_url",),
+            allow_none=True,
+            desc="get_user_in_directory",
+        )
+
+    @cached()
+    def get_user_in_public_room(self, user_id):
+        return self._simple_select_one(
+            table="users_in_public_rooms",
+            keyvalues={"user_id": user_id},
+            retcols=("room_id",),
+            allow_none=True,
+            desc="get_user_in_public_room",
+        )
+
+    def get_user_directory_stream_pos(self):
+        return self._simple_select_one_onecol(
+            table="user_directory_stream_pos",
+            keyvalues={},
+            retcol="stream_id",
+            desc="get_user_directory_stream_pos",
+        )
+
+    def update_user_directory_stream_pos(self, stream_id):
+        return self._simple_update_one(
+            table="user_directory_stream_pos",
+            keyvalues={},
+            updatevalues={"stream_id": stream_id},
+            desc="update_user_directory_stream_pos",
+        )
+
+    def get_current_state_deltas(self, prev_stream_id):
+        prev_stream_id = int(prev_stream_id)
+        if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
+            return []
+
+        def get_current_state_deltas_txn(txn):
+            # First we calculate the max stream id that will give us less than
+            # N results.
+            # We arbitarily limit to 100 stream_id entries to ensure we don't
+            # select toooo many.
+            sql = """
+                SELECT stream_id, count(*)
+                FROM current_state_delta_stream
+                WHERE stream_id > ?
+                GROUP BY stream_id
+                ORDER BY stream_id ASC
+                LIMIT 100
+            """
+            txn.execute(sql, (prev_stream_id,))
+
+            total = 0
+            max_stream_id = prev_stream_id
+            for max_stream_id, count in txn:
+                total += count
+                if total > 100:
+                    # We arbitarily limit to 100 entries to ensure we don't
+                    # select toooo many.
+                    break
+
+            # Now actually get the deltas
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            txn.execute(sql, (prev_stream_id, max_stream_id,))
+            return self.cursor_to_dict(txn)
+
+        return self.runInteraction(
+            "get_current_state_deltas", get_current_state_deltas_txn
+        )
+
+    def get_max_stream_id_in_current_state_deltas(self):
+        return self._simple_select_one_onecol(
+            table="current_state_delta_stream",
+            keyvalues={},
+            retcol="COALESCE(MAX(stream_id), -1)",
+            desc="get_max_stream_id_in_current_state_deltas",
+        )
+
+    @defer.inlineCallbacks
+    def search_user_dir(self, user_id, search_term, limit):
+        """Searches for users in directory
+
+        Returns:
+            dict of the form::
+
+                {
+                    "limited": <bool>,  # whether there were more results or not
+                    "results": [  # Ordered by best match first
+                        {
+                            "user_id": <user_id>,
+                            "display_name": <display_name>,
+                            "avatar_url": <avatar_url>
+                        }
+                    ]
+                }
+        """
+
+        if self.hs.config.user_directory_search_all_users:
+            # make s.user_id null to keep the ordering algorithm happy
+            join_clause = """
+                CROSS JOIN (SELECT NULL as user_id) AS s
+            """
+            join_args = ()
+            where_clause = "1=1"
+        else:
+            join_clause = """
+                LEFT JOIN users_in_public_rooms AS p USING (user_id)
+                LEFT JOIN (
+                    SELECT other_user_id AS user_id FROM users_who_share_rooms
+                    WHERE user_id = ? AND share_private
+                ) AS s USING (user_id)
+            """
+            join_args = (user_id,)
+            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+
+        if isinstance(self.database_engine, PostgresEngine):
+            full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+
+            # We order by rank and then if they have profile info
+            # The ranking algorithm is hand tweaked for "best" results. Broadly
+            # the idea is we give a higher weight to exact matches.
+            # The array of numbers are the weights for the various part of the
+            # search: (domain, _, display name, localpart)
+            sql = """
+                SELECT d.user_id AS user_id, display_name, avatar_url
+                FROM user_directory_search
+                INNER JOIN user_directory AS d USING (user_id)
+                %s
+                WHERE
+                    %s
+                    AND vector @@ to_tsquery('english', ?)
+                ORDER BY
+                    (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+                    * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
+                    * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
+                    * (
+                        3 * ts_rank_cd(
+                            '{0.1, 0.1, 0.9, 1.0}',
+                            vector,
+                            to_tsquery('english', ?),
+                            8
+                        )
+                        + ts_rank_cd(
+                            '{0.1, 0.1, 0.9, 1.0}',
+                            vector,
+                            to_tsquery('english', ?),
+                            8
+                        )
+                    )
+                    DESC,
+                    display_name IS NULL,
+                    avatar_url IS NULL
+                LIMIT ?
+            """ % (join_clause, where_clause)
+            args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            search_query = _parse_query_sqlite(search_term)
+
+            sql = """
+                SELECT d.user_id AS user_id, display_name, avatar_url
+                FROM user_directory_search
+                INNER JOIN user_directory AS d USING (user_id)
+                %s
+                WHERE
+                    %s
+                    AND value MATCH ?
+                ORDER BY
+                    rank(matchinfo(user_directory_search)) DESC,
+                    display_name IS NULL,
+                    avatar_url IS NULL
+                LIMIT ?
+            """ % (join_clause, where_clause)
+            args = join_args + (search_query, limit + 1)
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        results = yield self._execute(
+            "search_user_dir", self.cursor_to_dict, sql, *args
+        )
+
+        limited = len(results) > limit
+
+        defer.returnValue({
+            "limited": limited,
+            "results": results,
+        })
+
+
+def _parse_query_sqlite(search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to database.
+    We use this so that we can add prefix matching, which isn't something
+    that is supported by default.
+
+    We specifically add both a prefix and non prefix matching term so that
+    exact matches get ranked higher.
+    """
+
+    # Pull out the individual words, discarding any non-word characters.
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+    return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
+
+
+def _parse_query_postgres(search_term):
+    """Takes a plain unicode string from the user and converts it into a form
+    that can be passed to database.
+    We use this so that we can add prefix matching, which isn't something
+    that is supported by default.
+    """
+
+    # Pull out the individual words, discarding any non-word characters.
+    results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+
+    both = " & ".join("(%s:* | %s)" % (result, result,) for result in results)
+    exact = " & ".join("%s" % (result,) for result in results)
+    prefix = " & ".join("%s:*" % (result,) for result in results)
+
+    return both, exact, prefix
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 46cf93ff87..95031dc9ec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -30,6 +30,17 @@ class IdGenerator(object):
 
 
 def _load_current_id(db_conn, table, column, step=1):
+    """
+
+    Args:
+        db_conn (object):
+        table (str):
+        column (str):
+        step (int):
+
+    Returns:
+        int
+    """
     cur = db_conn.cursor()
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
@@ -131,6 +142,9 @@ class StreamIdGenerator(object):
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
+
+        Returns:
+            int
         """
         with self._lock:
             if self._unfinished_ids:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 4f089bfb94..ca78e551cb 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -80,13 +80,13 @@ class PaginationConfig(object):
                 from_tok = None  # For backwards compat.
             elif from_tok:
                 from_tok = StreamToken.from_string(from_tok)
-        except:
+        except Exception:
             raise SynapseError(400, "'from' paramater is invalid")
 
         try:
             if to_tok:
                 to_tok = StreamToken.from_string(to_tok)
-        except:
+        except Exception:
             raise SynapseError(400, "'to' paramater is invalid")
 
         limit = get_param("limit", None)
@@ -98,7 +98,7 @@ class PaginationConfig(object):
 
         try:
             return PaginationConfig(from_tok, to_tok, direction, limit)
-        except:
+        except Exception:
             logger.exception("Failed to create pagination config")
             raise SynapseError(400, "Invalid request.")
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 91a59b0bae..f03ad99118 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -45,6 +45,7 @@ class EventSources(object):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
+        groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -65,6 +66,7 @@ class EventSources(object):
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
+            groups_key=groups_key,
         )
         defer.returnValue(token)
 
@@ -73,6 +75,7 @@ class EventSources(object):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
+        groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -93,5 +96,6 @@ class EventSources(object):
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
+            groups_key=groups_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 9666f9d73f..cc7c182a78 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -12,26 +12,66 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import string
 
 from synapse.api.errors import SynapseError
 
 from collections import namedtuple
 
 
-Requester = namedtuple("Requester", [
+class Requester(namedtuple("Requester", [
     "user", "access_token_id", "is_guest", "device_id", "app_service",
-])
-"""
-Represents the user making a request
+])):
+    """
+    Represents the user making a request
 
-Attributes:
-    user (UserID):  id of the user making the request
-    access_token_id (int|None):  *ID* of the access token used for this
-        request, or None if it came via the appservice API or similar
-    is_guest (bool):  True if the user making this request is a guest user
-    device_id (str|None):  device_id which was set at authentication time
-    app_service (ApplicationService|None):  the AS requesting on behalf of the user
-"""
+    Attributes:
+        user (UserID):  id of the user making the request
+        access_token_id (int|None):  *ID* of the access token used for this
+            request, or None if it came via the appservice API or similar
+        is_guest (bool):  True if the user making this request is a guest user
+        device_id (str|None):  device_id which was set at authentication time
+        app_service (ApplicationService|None):  the AS requesting on behalf of the user
+    """
+
+    def serialize(self):
+        """Converts self to a type that can be serialized as JSON, and then
+        deserialized by `deserialize`
+
+        Returns:
+            dict
+        """
+        return {
+            "user_id": self.user.to_string(),
+            "access_token_id": self.access_token_id,
+            "is_guest": self.is_guest,
+            "device_id": self.device_id,
+            "app_server_id": self.app_service.id if self.app_service else None,
+        }
+
+    @staticmethod
+    def deserialize(store, input):
+        """Converts a dict that was produced by `serialize` back into a
+        Requester.
+
+        Args:
+            store (DataStore): Used to convert AS ID to AS object
+            input (dict): A dict produced by `serialize`
+
+        Returns:
+            Requester
+        """
+        appservice = None
+        if input["app_server_id"]:
+            appservice = store.get_app_service_by_id(input["app_server_id"])
+
+        return Requester(
+            user=UserID.from_string(input["user_id"]),
+            access_token_id=input["access_token_id"],
+            is_guest=input["is_guest"],
+            device_id=input["device_id"],
+            app_service=appservice,
+        )
 
 
 def create_requester(user_id, access_token_id=None, is_guest=False,
@@ -56,10 +96,17 @@ def create_requester(user_id, access_token_id=None, is_guest=False,
 
 
 def get_domain_from_id(string):
-    try:
-        return string.split(":", 1)[1]
-    except IndexError:
+    idx = string.find(":")
+    if idx == -1:
+        raise SynapseError(400, "Invalid ID: %r" % (string,))
+    return string[idx + 1:]
+
+
+def get_localpart_from_id(string):
+    idx = string.find(":")
+    if idx == -1:
         raise SynapseError(400, "Invalid ID: %r" % (string,))
+    return string[1:idx]
 
 
 class DomainSpecificString(
@@ -119,14 +166,10 @@ class DomainSpecificString(
         try:
             cls.from_string(s)
             return True
-        except:
+        except Exception:
             return False
 
-    __str__ = to_string
-
-    @classmethod
-    def create(cls, localpart, domain,):
-        return cls(localpart=localpart, domain=domain)
+    __repr__ = to_string
 
 
 class UserID(DomainSpecificString):
@@ -149,6 +192,43 @@ class EventID(DomainSpecificString):
     SIGIL = "$"
 
 
+class GroupID(DomainSpecificString):
+    """Structure representing a group ID."""
+    SIGIL = "+"
+
+    @classmethod
+    def from_string(cls, s):
+        group_id = super(GroupID, cls).from_string(s)
+        if not group_id.localpart:
+            raise SynapseError(
+                400,
+                "Group ID cannot be empty",
+            )
+
+        if contains_invalid_mxid_characters(group_id.localpart):
+            raise SynapseError(
+                400,
+                "Group ID can only contain characters a-z, 0-9, or '=_-./'",
+            )
+
+        return group_id
+
+
+mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
+
+
+def contains_invalid_mxid_characters(localpart):
+    """Check for characters not allowed in an mxid or groupid localpart
+
+    Args:
+        localpart (basestring): the localpart to be checked
+
+    Returns:
+        bool: True if there are any naughty characters
+    """
+    return any(c not in mxid_localpart_allowed_characters for c in localpart)
+
+
 class StreamToken(
     namedtuple("Token", (
         "room_key",
@@ -159,6 +239,7 @@ class StreamToken(
         "push_rules_key",
         "to_device_key",
         "device_list_key",
+        "groups_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -171,7 +252,7 @@ class StreamToken(
                 # i.e. old token from before receipt_key
                 keys.append("0")
             return cls(*keys)
-        except:
+        except Exception:
             raise SynapseError(400, "Invalid Token")
 
     def to_string(self):
@@ -197,6 +278,7 @@ class StreamToken(
             or (int(other.push_rules_key) < int(self.push_rules_key))
             or (int(other.to_device_key) < int(self.to_device_key))
             or (int(other.device_list_key) < int(self.device_list_key))
+            or (int(other.groups_key) < int(self.groups_key))
         )
 
     def copy_and_advance(self, key, new_value):
@@ -216,9 +298,7 @@ class StreamToken(
             return self
 
     def copy_and_replace(self, key, new_value):
-        d = self._asdict()
-        d[key] = new_value
-        return StreamToken(**d)
+        return self._replace(**{key: new_value})
 
 
 StreamToken.START = StreamToken(
@@ -258,7 +338,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
             if string[0] == 't':
                 parts = string[1:].split('-', 1)
                 return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
+        except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
 
@@ -267,7 +347,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
         try:
             if string[0] == 's':
                 return cls(topological=None, stream=int(string[1:]))
-        except:
+        except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
 
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 30fc480108..814a7bf71b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError
 from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
@@ -24,11 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DeferredTimedOutError(SynapseError):
-    def __init__(self):
-        super(SynapseError).__init__(504, "Timed out")
-
-
 def unwrapFirstError(failure):
     # defer.gatherResults and DeferredLists wrap failures.
     failure.trap(defer.FirstError)
@@ -59,9 +53,9 @@ class Clock(object):
             f(function): The function to call repeatedly.
             msec(float): How long to wait between calls in milliseconds.
         """
-        l = task.LoopingCall(f)
-        l.start(msec / 1000.0, now=False)
-        return l
+        call = task.LoopingCall(f)
+        call.start(msec / 1000.0, now=False)
+        return call
 
     def call_later(self, delay, callback, *args, **kwargs):
         """Call something later
@@ -82,54 +76,6 @@ class Clock(object):
     def cancel_call_later(self, timer, ignore_errs=False):
         try:
             timer.cancel()
-        except:
+        except Exception:
             if not ignore_errs:
                 raise
-
-    def time_bound_deferred(self, given_deferred, time_out):
-        if given_deferred.called:
-            return given_deferred
-
-        ret_deferred = defer.Deferred()
-
-        def timed_out_fn():
-            try:
-                ret_deferred.errback(DeferredTimedOutError())
-            except:
-                pass
-
-            try:
-                given_deferred.cancel()
-            except:
-                pass
-
-        timer = None
-
-        def cancel(res):
-            try:
-                self.cancel_call_later(timer)
-            except:
-                pass
-            return res
-
-        ret_deferred.addBoth(cancel)
-
-        def sucess(res):
-            try:
-                ret_deferred.callback(res)
-            except:
-                pass
-
-            return res
-
-        def err(res):
-            try:
-                ret_deferred.errback(res)
-            except:
-                pass
-
-        given_deferred.addCallbacks(callback=sucess, errback=err)
-
-        timer = self.call_later(time_out, timed_out_fn)
-
-        return ret_deferred
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 35380bf8ed..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,16 +15,20 @@
 
 
 from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
 
 from .logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
+    PreserveLoggingContext, make_deferred_yieldable, run_in_background
 )
-from synapse.util import unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError
 
 from contextlib import contextmanager
 
 import logging
 
+from six.moves import range
+
 logger = logging.getLogger(__name__)
 
 
@@ -53,6 +57,11 @@ class ObservableDeferred(object):
 
     Cancelling or otherwise resolving an observer will not affect the original
     ObservableDeferred.
+
+    NB that it does not attempt to do anything with logcontexts; in general
+    you should probably make_deferred_yieldable the deferreds
+    returned by `observe`, and ensure that the original deferred runs its
+    callbacks in the sentinel logcontext.
     """
 
     __slots__ = ["_deferred", "_observers", "_result"]
@@ -68,7 +77,7 @@ class ObservableDeferred(object):
                 try:
                     # TODO: Handle errors here.
                     self._observers.pop().callback(r)
-                except:
+                except Exception:
                     pass
             return r
 
@@ -78,7 +87,7 @@ class ObservableDeferred(object):
                 try:
                     # TODO: Handle errors here.
                     self._observers.pop().errback(f)
-                except:
+                except Exception:
                     pass
 
             if consumeErrors:
@@ -89,6 +98,11 @@ class ObservableDeferred(object):
         deferred.addCallbacks(callback, errback)
 
     def observe(self):
+        """Observe the underlying deferred.
+
+        Can return either a deferred if the underlying deferred is still pending
+        (or has failed), or the actual value. Callers may need to use maybeDeferred.
+        """
         if not self._result:
             d = defer.Deferred()
 
@@ -101,7 +115,7 @@ class ObservableDeferred(object):
             return d
         else:
             success, res = self._result
-            return defer.succeed(res) if success else defer.fail(res)
+            return res if success else defer.fail(res)
 
     def observers(self):
         return self._observers
@@ -146,13 +160,13 @@ def concurrently_execute(func, args, limit):
     def _concurrently_execute_inner():
         try:
             while True:
-                yield func(it.next())
+                yield func(next(it))
         except StopIteration:
             pass
 
-    return preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(_concurrently_execute_inner)()
-        for _ in xrange(limit)
+    return logcontext.make_deferred_yieldable(defer.gatherResults([
+        run_in_background(_concurrently_execute_inner)
+        for _ in range(limit)
     ], consumeErrors=True)).addErrback(unwrapFirstError)
 
 
@@ -195,10 +209,29 @@ class Linearizer(object):
             try:
                 with PreserveLoggingContext():
                     yield current_defer
-            except:
+            except Exception:
                 logger.exception("Unexpected exception in Linearizer")
 
-        logger.info("Acquired linearizer lock %r for key %r", self.name, key)
+            logger.info("Acquired linearizer lock %r for key %r", self.name,
+                        key)
+
+            # if the code holding the lock completes synchronously, then it
+            # will recursively run the next claimant on the list. That can
+            # relatively rapidly lead to stack exhaustion. This is essentially
+            # the same problem as http://twistedmatrix.com/trac/ticket/9304.
+            #
+            # In order to break the cycle, we add a cheeky sleep(0) here to
+            # ensure that we fall back to the reactor between each iteration.
+            #
+            # (There's no particular need for it to happen before we return
+            # the context manager, but it needs to happen while we hold the
+            # lock, and the context manager's exit code must be synchronous,
+            # so actually this is the only sensible place.
+            yield run_on_reactor()
+
+        else:
+            logger.info("Acquired uncontended linearizer lock %r for key %r",
+                        self.name, key)
 
         @contextmanager
         def _ctx_manager():
@@ -206,7 +239,8 @@ class Linearizer(object):
                 yield
             finally:
                 logger.info("Releasing linearizer lock %r for key %r", self.name, key)
-                new_defer.callback(None)
+                with PreserveLoggingContext():
+                    new_defer.callback(None)
                 current_d = self.key_to_defer.get(key)
                 if current_d is new_defer:
                     self.key_to_defer.pop(key, None)
@@ -248,8 +282,13 @@ class Limiter(object):
         if entry[0] >= self.max_count:
             new_defer = defer.Deferred()
             entry[1].append(new_defer)
+
+            logger.info("Waiting to acquire limiter lock for key %r", key)
             with PreserveLoggingContext():
                 yield new_defer
+            logger.info("Acquired limiter lock for key %r", key)
+        else:
+            logger.info("Acquired uncontended limiter lock for key %r", key)
 
         entry[0] += 1
 
@@ -258,16 +297,21 @@ class Limiter(object):
             try:
                 yield
             finally:
+                logger.info("Releasing limiter lock for key %r", key)
+
                 # We've finished executing so check if there are any things
                 # blocked waiting to execute and start one of them
                 entry[0] -= 1
-                try:
-                    entry[1].pop(0).callback(None)
-                except IndexError:
-                    # If nothing else is executing for this key then remove it
-                    # from the map
-                    if entry[0] == 0:
-                        self.key_to_defer.pop(key, None)
+
+                if entry[1]:
+                    next_def = entry[1].pop(0)
+
+                    with PreserveLoggingContext():
+                        next_def.callback(None)
+                elif entry[0] == 0:
+                    # We were the last thing for this key: remove it from the
+                    # map.
+                    del self.key_to_defer[key]
 
         defer.returnValue(_ctx_manager())
 
@@ -311,7 +355,7 @@ class ReadWriteLock(object):
 
         # We wait for the latest writer to finish writing. We can safely ignore
         # any existing readers... as they're readers.
-        yield curr_writer
+        yield make_deferred_yieldable(curr_writer)
 
         @contextmanager
         def _ctx_manager():
@@ -340,7 +384,7 @@ class ReadWriteLock(object):
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
+        yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
         @contextmanager
         def _ctx_manager():
@@ -352,3 +396,68 @@ class ReadWriteLock(object):
                     self.key_to_current_writer.pop(key)
 
         defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+    """
+    This error is raised by default when a L{Deferred} times out.
+    """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+    """
+    Add a timeout to a deferred by scheduling it to be cancelled after
+    timeout seconds.
+
+    This is essentially a backport of deferred.addTimeout, which was introduced
+    in twisted 16.5.
+
+    If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+    unless a cancelable function was passed to its initialization or unless
+    a different on_timeout_cancel callable is provided.
+
+    Args:
+        deferred (defer.Deferred): deferred to be timed out
+        timeout (Number): seconds to time out after
+
+        on_timeout_cancel (callable): A callable which is called immediately
+            after the deferred times out, and not if this deferred is
+            otherwise cancelled before the timeout.
+
+            It takes an arbitrary value, which is the value of the deferred at
+            that exact point in time (probably a CancelledError Failure), and
+            the timeout.
+
+            The default callable (if none is provided) will translate a
+            CancelledError Failure into a DeferredTimeoutError.
+    """
+    timed_out = [False]
+
+    def time_it_out():
+        timed_out[0] = True
+        deferred.cancel()
+
+    delayed_call = reactor.callLater(timeout, time_it_out)
+
+    def convert_cancelled(value):
+        if timed_out[0]:
+            to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+            return to_call(value, timeout)
+        return value
+
+    deferred.addBoth(convert_cancelled)
+
+    def cancel_timeout(result):
+        # stop the pending call to cancel the deferred if it's been fired
+        if delayed_call.active():
+            delayed_call.cancel()
+        return result
+
+    deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+    if isinstance(value, failure.Failure):
+        value.trap(CancelledError)
+        raise DeferredTimeoutError(timeout, "Deferred")
+    return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8a7774a88e..4adae96681 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -14,12 +14,9 @@
 # limitations under the License.
 
 import synapse.metrics
-from lrucache import LruCache
 import os
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-DEBUG_CACHES = False
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
 
 metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
 
@@ -40,10 +37,6 @@ def register_cache(name, cache):
     )
 
 
-_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
-_stirng_cache_metrics = register_cache("string_cache", _string_cache)
-
-
 KNOWN_KEYS = {
     key: key for key in
     (
@@ -67,14 +60,16 @@ KNOWN_KEYS = {
 
 
 def intern_string(string):
-    """Takes a (potentially) unicode string and interns using custom cache
+    """Takes a (potentially) unicode string and interns it if it's ascii
     """
-    new_str = _string_cache.setdefault(string, string)
-    if new_str is string:
-        _stirng_cache_metrics.inc_hits()
-    else:
-        _stirng_cache_metrics.inc_misses()
-    return new_str
+    if string is None:
+        return None
+
+    try:
+        string = string.encode("ascii")
+        return intern(string)
+    except UnicodeEncodeError:
+        return string
 
 
 def intern_dict(dictionary):
@@ -87,13 +82,9 @@ def intern_dict(dictionary):
 
 
 def _intern_known_values(key, value):
-    intern_str_keys = ("event_id", "room_id")
-    intern_unicode_keys = ("sender", "user_id", "type", "state_key")
-
-    if key in intern_str_keys:
-        return intern(value.encode('ascii'))
+    intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
 
-    if key in intern_unicode_keys:
+    if key in intern_keys:
         return intern_string(value)
 
     return value
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 998de70d29..68285a7594 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,19 +16,17 @@
 import logging
 
 from synapse.util.async import ObservableDeferred
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
+from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
-)
+from synapse.util.stringutils import to_ascii
 
-from . import DEBUG_CACHES, register_cache
+from . import register_cache
 
 from twisted.internet import defer
 from collections import namedtuple
 
-import os
 import functools
 import inspect
 import threading
@@ -39,17 +38,13 @@ logger = logging.getLogger(__name__)
 _CacheSentinel = object()
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class CacheEntry(object):
     __slots__ = [
-        "deferred", "sequence", "callbacks", "invalidated"
+        "deferred", "callbacks", "invalidated"
     ]
 
-    def __init__(self, deferred, sequence, callbacks):
+    def __init__(self, deferred, callbacks):
         self.deferred = deferred
-        self.sequence = sequence
         self.callbacks = set(callbacks)
         self.invalidated = False
 
@@ -67,7 +62,6 @@ class Cache(object):
         "max_entries",
         "name",
         "keylen",
-        "sequence",
         "thread",
         "metrics",
         "_pending_deferred_cache",
@@ -79,15 +73,18 @@ class Cache(object):
 
         self.cache = LruCache(
             max_size=max_entries, keylen=keylen, cache_type=cache_type,
-            size_callback=(lambda d: len(d.result)) if iterable else None,
+            size_callback=(lambda d: len(d)) if iterable else None,
+            evicted_callback=self._on_evicted,
         )
 
         self.name = name
         self.keylen = keylen
-        self.sequence = 0
         self.thread = None
         self.metrics = register_cache(name, self.cache)
 
+    def _on_evicted(self, evicted_count):
+        self.metrics.inc_evictions(evicted_count)
+
     def check_thread(self):
         expected_thread = self.thread
         if expected_thread is None:
@@ -98,21 +95,34 @@ class Cache(object):
                     "Cache objects can only be accessed from the main thread"
                 )
 
-    def get(self, key, default=_CacheSentinel, callback=None):
+    def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
+        """Looks the key up in the caches.
+
+        Args:
+            key(tuple)
+            default: What is returned if key is not in the caches. If not
+                specified then function throws KeyError instead
+            callback(fn): Gets called when the entry in the cache is invalidated
+            update_metrics (bool): whether to update the cache hit rate metrics
+
+        Returns:
+            Either a Deferred or the raw result
+        """
         callbacks = [callback] if callback else []
         val = self._pending_deferred_cache.get(key, _CacheSentinel)
         if val is not _CacheSentinel:
-            if val.sequence == self.sequence:
-                val.callbacks.update(callbacks)
+            val.callbacks.update(callbacks)
+            if update_metrics:
                 self.metrics.inc_hits()
-                return val.deferred
+            return val.deferred
 
         val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
         if val is not _CacheSentinel:
             self.metrics.inc_hits()
             return val
 
-        self.metrics.inc_misses()
+        if update_metrics:
+            self.metrics.inc_misses()
 
         if default is _CacheSentinel:
             raise KeyError()
@@ -124,12 +134,9 @@ class Cache(object):
         self.check_thread()
         entry = CacheEntry(
             deferred=value,
-            sequence=self.sequence,
             callbacks=callbacks,
         )
 
-        entry.callbacks.update(callbacks)
-
         existing_entry = self._pending_deferred_cache.pop(key, None)
         if existing_entry:
             existing_entry.invalidate()
@@ -137,13 +144,25 @@ class Cache(object):
         self._pending_deferred_cache[key] = entry
 
         def shuffle(result):
-            if self.sequence == entry.sequence:
-                existing_entry = self._pending_deferred_cache.pop(key, None)
-                if existing_entry is entry:
-                    self.cache.set(key, entry.deferred, entry.callbacks)
-                else:
-                    entry.invalidate()
+            existing_entry = self._pending_deferred_cache.pop(key, None)
+            if existing_entry is entry:
+                self.cache.set(key, result, entry.callbacks)
             else:
+                # oops, the _pending_deferred_cache has been updated since
+                # we started our query, so we are out of date.
+                #
+                # Better put back whatever we took out. (We do it this way
+                # round, rather than peeking into the _pending_deferred_cache
+                # and then removing on a match, to make the common case faster)
+                if existing_entry is not None:
+                    self._pending_deferred_cache[key] = existing_entry
+
+                # we're not going to put this entry into the cache, so need
+                # to make sure that the invalidation callbacks are called.
+                # That was probably done when _pending_deferred_cache was
+                # updated, but it's possible that `set` was called without
+                # `invalidate` being previously called, in which case it may
+                # not have been. Either way, let's double-check now.
                 entry.invalidate()
             return result
 
@@ -155,29 +174,29 @@ class Cache(object):
 
     def invalidate(self, key):
         self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )
+        self.cache.pop(key, None)
 
-        # Increment the sequence number so that any SELECT statements that
-        # raced with the INSERT don't update the cache (SYN-369)
-        self.sequence += 1
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, which will (a) stop it being returned
+        # for future queries and (b) stop it being persisted as a proper entry
+        # in self.cache.
         entry = self._pending_deferred_cache.pop(key, None)
+
+        # run the invalidation callbacks now, rather than waiting for the
+        # deferred to resolve.
         if entry:
             entry.invalidate()
 
-        self.cache.pop(key, None)
-
     def invalidate_many(self, key):
         self.check_thread()
         if not isinstance(key, tuple):
             raise TypeError(
                 "The cache key must be a tuple not %r" % (type(key),)
             )
-        self.sequence += 1
         self.cache.del_multi(key)
 
+        # if we have a pending lookup for this key, remove it from the
+        # _pending_deferred_cache, as above
         entry_dict = self._pending_deferred_cache.pop(key, None)
         if entry_dict is not None:
             for entry in iterate_tree_cache_entry(entry_dict):
@@ -185,11 +204,73 @@ class Cache(object):
 
     def invalidate_all(self):
         self.check_thread()
-        self.sequence += 1
         self.cache.clear()
+        for entry in self._pending_deferred_cache.itervalues():
+            entry.invalidate()
+        self._pending_deferred_cache.clear()
+
+
+class _CacheDescriptorBase(object):
+    def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+        self.orig = orig
+
+        if inlineCallbacks:
+            self.function_to_call = defer.inlineCallbacks(orig)
+        else:
+            self.function_to_call = orig
 
+        arg_spec = inspect.getargspec(orig)
+        all_args = arg_spec.args
 
-class CacheDescriptor(object):
+        if "cache_context" in all_args:
+            if not cache_context:
+                raise ValueError(
+                    "Cannot have a 'cache_context' arg without setting"
+                    " cache_context=True"
+                )
+        elif cache_context:
+            raise ValueError(
+                "Cannot have cache_context=True without having an arg"
+                " named `cache_context`"
+            )
+
+        if num_args is None:
+            num_args = len(all_args) - 1
+            if cache_context:
+                num_args -= 1
+
+        if len(all_args) < num_args + 1:
+            raise Exception(
+                "Not enough explicit positional arguments to key off for %r: "
+                "got %i args, but wanted %i. (@cached cannot key off *args or "
+                "**kwargs)"
+                % (orig.__name__, len(all_args), num_args)
+            )
+
+        self.num_args = num_args
+
+        # list of the names of the args used as the cache key
+        self.arg_names = all_args[1:num_args + 1]
+
+        # self.arg_defaults is a map of arg name to its default value for each
+        # argument that has a default value
+        if arg_spec.defaults:
+            self.arg_defaults = dict(zip(
+                all_args[-len(arg_spec.defaults):],
+                arg_spec.defaults
+            ))
+        else:
+            self.arg_defaults = {}
+
+        if "cache_context" in self.arg_names:
+            raise Exception(
+                "cache_context arg cannot be included among the cache keys"
+            )
+
+        self.add_cache_context = cache_context
+
+
+class CacheDescriptor(_CacheDescriptorBase):
     """ A method decorator that applies a memoizing cache around the function.
 
     This caches deferreds, rather than the results themselves. Deferreds that
@@ -217,52 +298,24 @@ class CacheDescriptor(object):
             r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
             defer.returnValue(r1 + r2)
 
+    Args:
+        num_args (int): number of positional arguments (excluding ``self`` and
+            ``cache_context``) to use as cache keys. Defaults to all named
+            args of the function.
     """
-    def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
+    def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
                  inlineCallbacks=False, cache_context=False, iterable=False):
-        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
-        self.orig = orig
+        super(CacheDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
+            cache_context=cache_context)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
+        max_entries = int(max_entries * CACHE_SIZE_FACTOR)
 
         self.max_entries = max_entries
-        self.num_args = num_args
         self.tree = tree
-
         self.iterable = iterable
 
-        all_args = inspect.getargspec(orig)
-        self.arg_names = all_args.args[1:num_args + 1]
-
-        if "cache_context" in all_args.args:
-            if not cache_context:
-                raise ValueError(
-                    "Cannot have a 'cache_context' arg without setting"
-                    " cache_context=True"
-                )
-            try:
-                self.arg_names.remove("cache_context")
-            except ValueError:
-                pass
-        elif cache_context:
-            raise ValueError(
-                "Cannot have cache_context=True without having an arg"
-                " named `cache_context`"
-            )
-
-        self.add_cache_context = cache_context
-
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwargs)"
-                % (orig.__name__,)
-            )
-
     def __get__(self, obj, objtype=None):
         cache = Cache(
             name=self.orig.__name__,
@@ -272,18 +325,47 @@ class CacheDescriptor(object):
             iterable=self.iterable,
         )
 
+        def get_cache_key_gen(args, kwargs):
+            """Given some args/kwargs return a generator that resolves into
+            the cache_key.
+
+            We loop through each arg name, looking up if its in the `kwargs`,
+            otherwise using the next argument in `args`. If there are no more
+            args then we try looking the arg name up in the defaults
+            """
+            pos = 0
+            for nm in self.arg_names:
+                if nm in kwargs:
+                    yield kwargs[nm]
+                elif pos < len(args):
+                    yield args[pos]
+                    pos += 1
+                else:
+                    yield self.arg_defaults[nm]
+
+        # By default our cache key is a tuple, but if there is only one item
+        # then don't bother wrapping in a tuple.  This is to save memory.
+        if self.num_args == 1:
+            nm = self.arg_names[0]
+
+            def get_cache_key(args, kwargs):
+                if nm in kwargs:
+                    return kwargs[nm]
+                elif len(args):
+                    return args[0]
+                else:
+                    return self.arg_defaults[nm]
+        else:
+            def get_cache_key(args, kwargs):
+                return tuple(get_cache_key_gen(args, kwargs))
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
-            # Add temp cache_context so inspect.getcallargs doesn't explode
-            if self.add_cache_context:
-                kwargs["cache_context"] = None
-
-            arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
-            cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
+            cache_key = get_cache_key(args, kwargs)
 
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -293,26 +375,14 @@ class CacheDescriptor(object):
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
 
-                observer = cached_result_d.observe()
-                if DEBUG_CACHES:
-                    @defer.inlineCallbacks
-                    def check_result(cached_result):
-                        actual_result = yield self.function_to_call(obj, *args, **kwargs)
-                        if actual_result != cached_result:
-                            logger.error(
-                                "Stale cache entry %s%r: cached: %r, actual %r",
-                                self.orig.__name__, cache_key,
-                                cached_result, actual_result,
-                            )
-                            raise ValueError("Stale cache entry")
-                        defer.returnValue(cached_result)
-                    observer.addCallback(check_result)
-
-                return preserve_context_over_deferred(observer)
+                if isinstance(cached_result_d, ObservableDeferred):
+                    observer = cached_result_d.observe()
+                else:
+                    observer = cached_result_d
+
             except KeyError:
                 ret = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     obj, *args, **kwargs
                 )
 
@@ -322,64 +392,72 @@ class CacheDescriptor(object):
 
                 ret.addErrback(onErr)
 
-                ret = ObservableDeferred(ret, consumeErrors=True)
-                cache.set(cache_key, ret, callback=invalidate_callback)
+                # If our cache_key is a string, try to convert to ascii to save
+                # a bit of space in large caches
+                if isinstance(cache_key, basestring):
+                    cache_key = to_ascii(cache_key)
 
-                return preserve_context_over_deferred(ret.observe())
+                result_d = ObservableDeferred(ret, consumeErrors=True)
+                cache.set(cache_key, result_d, callback=invalidate_callback)
+                observer = result_d.observe()
+
+            if isinstance(observer, defer.Deferred):
+                return logcontext.make_deferred_yieldable(observer)
+            else:
+                return observer
+
+        if self.num_args == 1:
+            wrapped.invalidate = lambda key: cache.invalidate(key[0])
+            wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+        else:
+            wrapped.invalidate = cache.invalidate
+            wrapped.invalidate_all = cache.invalidate_all
+            wrapped.invalidate_many = cache.invalidate_many
+            wrapped.prefill = cache.prefill
 
-        wrapped.invalidate = cache.invalidate
         wrapped.invalidate_all = cache.invalidate_all
-        wrapped.invalidate_many = cache.invalidate_many
-        wrapped.prefill = cache.prefill
         wrapped.cache = cache
+        wrapped.num_args = self.num_args
 
         obj.__dict__[self.orig.__name__] = wrapped
 
         return wrapped
 
 
-class CacheListDescriptor(object):
+class CacheListDescriptor(_CacheDescriptorBase):
     """Wraps an existing cache to support bulk fetching of keys.
 
     Given a list of keys it looks in the cache to find any hits, then passes
-    the list of missing keys to the wrapped fucntion.
+    the list of missing keys to the wrapped function.
+
+    Once wrapped, the function returns either a Deferred which resolves to
+    the list of results, or (if all results were cached), just the list of
+    results.
     """
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=1,
+    def __init__(self, orig, cached_method_name, list_name, num_args=None,
                  inlineCallbacks=False):
         """
         Args:
             orig (function)
-            method_name (str); The name of the chached method.
+            cached_method_name (str): The name of the chached method.
             list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int)
+            num_args (int): number of positional arguments (excluding ``self``,
+                but including list_name) to use as cache keys. Defaults to all
+                named args of the function.
             inlineCallbacks (bool): Whether orig is a generator that should
                 be wrapped by defer.inlineCallbacks
         """
-        self.orig = orig
+        super(CacheListDescriptor, self).__init__(
+            orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
 
-        if inlineCallbacks:
-            self.function_to_call = defer.inlineCallbacks(orig)
-        else:
-            self.function_to_call = orig
-
-        self.num_args = num_args
         self.list_name = list_name
 
-        self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
         self.list_pos = self.arg_names.index(self.list_name)
-
         self.cached_method_name = cached_method_name
 
         self.sentinel = object()
 
-        if len(self.arg_names) < self.num_args:
-            raise Exception(
-                "Not enough explicit positional arguments to key off of for %r."
-                " (@cached cannot key off of *args or **kwars)"
-                % (orig.__name__,)
-            )
-
         if self.list_name not in self.arg_names:
             raise Exception(
                 "Couldn't see arguments %r for %r."
@@ -387,8 +465,9 @@ class CacheListDescriptor(object):
             )
 
     def __get__(self, obj, objtype=None):
-
-        cache = getattr(obj, self.cached_method_name).cache
+        cached_method = getattr(obj, self.cached_method_name)
+        cache = cached_method.cache
+        num_args = cached_method.num_args
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
@@ -405,13 +484,26 @@ class CacheListDescriptor(object):
             results = {}
             cached_defers = {}
             missing = []
-            for arg in list_args:
+
+            # If the cache takes a single arg then that is used as the key,
+            # otherwise a tuple is used.
+            if num_args == 1:
+                def cache_get(arg):
+                    return cache.get(arg, callback=invalidate_callback)
+            else:
                 key = list(keyargs)
-                key[self.list_pos] = arg
 
+                def cache_get(arg):
+                    key[self.list_pos] = arg
+                    return cache.get(tuple(key), callback=invalidate_callback)
+
+            for arg in list_args:
                 try:
-                    res = cache.get(tuple(key), callback=invalidate_callback)
-                    if not res.has_succeeded():
+                    res = cache_get(arg)
+
+                    if not isinstance(res, ObservableDeferred):
+                        results[arg] = res
+                    elif not res.has_succeeded():
                         res = res.observe()
                         res.addCallback(lambda r, arg: (arg, r), arg)
                         cached_defers[arg] = res
@@ -425,8 +517,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
-                    preserve_context_over_fn,
-                    self.function_to_call,
+                    logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
                 )
 
@@ -435,23 +526,33 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    with PreserveLoggingContext():
-                        observer = ret_d.observe()
+                    observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
 
-                    key = list(keyargs)
-                    key[self.list_pos] = arg
-                    cache.set(
-                        tuple(key), observer,
-                        callback=invalidate_callback
-                    )
+                    if num_args == 1:
+                        cache.set(
+                            arg, observer,
+                            callback=invalidate_callback
+                        )
 
-                    def invalidate(f, key):
-                        cache.invalidate(key)
-                        return f
-                    observer.addErrback(invalidate, tuple(key))
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, arg)
+                    else:
+                        key = list(keyargs)
+                        key[self.list_pos] = arg
+                        cache.set(
+                            tuple(key), observer,
+                            callback=invalidate_callback
+                        )
+
+                        def invalidate(f, key):
+                            cache.invalidate(key)
+                            return f
+                        observer.addErrback(invalidate, tuple(key))
 
                     res = observer.observe()
                     res.addCallback(lambda r, arg: (arg, r), arg)
@@ -463,7 +564,7 @@ class CacheListDescriptor(object):
                     results.update(res)
                     return results
 
-                return preserve_context_over_deferred(defer.gatherResults(
+                return logcontext.make_deferred_yieldable(defer.gatherResults(
                     cached_defers.values(),
                     consumeErrors=True,
                 ).addCallback(update_results_dict).addErrback(
@@ -487,7 +588,7 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
         self.cache.invalidate(self.key)
 
 
-def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
+def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
            iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
@@ -499,8 +600,8 @@ def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
     )
 
 
-def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
-                          iterable=False):
+def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
+                          cache_context=False, iterable=False):
     return lambda orig: CacheDescriptor(
         orig,
         max_entries=max_entries,
@@ -512,7 +613,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
     )
 
 
-def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False):
+def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
     """Creates a descriptor that wraps a function in a `CacheListDescriptor`.
 
     Used to do batch lookups for an already created cache. A single argument
@@ -525,7 +626,8 @@ def cachedList(cached_method_name, list_name, num_args=1, inlineCallbacks=False)
         cache (Cache): The underlying cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
-        num_args (int): Number of arguments to use as the key in the cache.
+        num_args (int): Number of arguments to use as the key in the cache
+            (including list_name). Defaults to all named parameters.
         inlineCallbacks (bool): Should the function be wrapped in an
             `defer.inlineCallbacks`?
 
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..1709e8b429 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+    """Returned when getting an entry from the cache
+
+    Attributes:
+        full (bool): Whether the cache has the full or dict or just some keys.
+            If not full then not all requested keys will necessarily be present
+            in `value`
+        known_absent (set): Keys that were looked up in the dict and were not
+            there.
+        value (dict): The full or partial dict value
+    """
     def __len__(self):
         return len(self.value)
 
@@ -58,21 +68,31 @@ class DictionaryCache(object):
                 )
 
     def get(self, key, dict_keys=None):
+        """Fetch an entry out of the cache
+
+        Args:
+            key
+            dict_key(list): If given a set of keys then return only those keys
+                that exist in the cache.
+
+        Returns:
+            DictionaryEntry
+        """
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, dict(entry.value))
+                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
             else:
-                return DictionaryEntry(entry.full, {
+                return DictionaryEntry(entry.full, entry.known_absent, {
                     k: entry.value[k]
                     for k in dict_keys
                     if k in entry.value
                 })
 
         self.metrics.inc_misses()
-        return DictionaryEntry(False, {})
+        return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
         self.check_thread()
@@ -87,19 +107,38 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False):
+    def update(self, sequence, key, value, full=False, known_absent=None):
+        """Updates the entry in the cache
+
+        Args:
+            sequence
+            key
+            value (dict): The value to update the cache with.
+            full (bool): Whether the given value is the full dict, or just a
+                partial subset there of. If not full then any existing entries
+                for the key will be updated.
+            known_absent (set): Set of keys that we know don't exist in the full
+                dict.
+        """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
+            if known_absent is None:
+                known_absent = set()
             if full:
-                self._insert(key, value)
+                self._insert(key, value, known_absent)
             else:
-                self._update_or_insert(key, value)
+                self._update_or_insert(key, value, known_absent)
+
+    def _update_or_insert(self, key, value, known_absent):
+        # We pop and reinsert as we need to tell the cache the size may have
+        # changed
 
-    def _update_or_insert(self, key, value):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+        entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
+        entry.known_absent.update(known_absent)
+        self.cache[key] = entry
 
-    def _insert(self, key, value):
-        self.cache[key] = DictionaryEntry(True, value)
+    def _insert(self, key, value, known_absent):
+        self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 2987c38a2d..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
         while self._max_len and len(self) > self._max_len:
             _key, value = self._cache.popitem(last=False)
             if self.iterable:
-                self._size_estimate -= len(value.value)
+                removed_len = len(value.value)
+                self.metrics.inc_evictions(removed_len)
+                self._size_estimate -= removed_len
+            else:
+                self.metrics.inc_evictions()
 
     def __getitem__(self, key):
         try:
@@ -94,12 +98,22 @@ class ExpiringCache(object):
 
         return entry.value
 
+    def __contains__(self, key):
+        return key in self._cache
+
     def get(self, key, default=None):
         try:
             return self[key]
         except KeyError:
             return default
 
+    def setdefault(self, key, value):
+        try:
+            return self[key]
+        except KeyError:
+            self[key] = value
+            return value
+
     def _prune_cache(self):
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
     Can also set callbacks on objects when getting/setting which are fired
     when that key gets invalidated/evicted.
     """
-    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+    def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+                 evicted_callback=None):
+        """
+        Args:
+            max_size (int):
+
+            keylen (int):
+
+            cache_type (type):
+                type of underlying cache to be used. Typically one of dict
+                or TreeCache.
+
+            size_callback (func(V) -> int | None):
+
+            evicted_callback (func(int)|None):
+                if not None, called on eviction with the size of the evicted
+                entry
+        """
         cache = cache_type()
         self.cache = cache  # Used for introspection.
         list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
         def evict():
             while cache_len() > max_size:
                 todelete = list_root.prev_node
-                delete_node(todelete)
+                evicted_len = delete_node(todelete)
                 cache.pop(todelete.key, None)
+                if evicted_callback:
+                    evicted_callback(evicted_len)
 
         def synchronized(f):
             @wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
             prev_node.next_node = next_node
             next_node.prev_node = prev_node
 
+            deleted_len = 1
             if size_callback:
-                cached_cache_len[0] -= size_callback(node.value)
+                deleted_len = size_callback(node.value)
+                cached_cache_len[0] -= deleted_len
 
             for cb in node.callbacks:
                 cb()
             node.callbacks.clear()
+            return deleted_len
 
         @synchronized
         def cache_get(key, default=None, callbacks=[]):
@@ -132,14 +154,21 @@ class LruCache(object):
         def cache_set(key, value, callbacks=[]):
             node = cache.get(key, None)
             if node is not None:
-                if value != node.value:
+                # We sometimes store large objects, e.g. dicts, which cause
+                # the inequality check to take a long time. So let's only do
+                # the check if we have some callbacks to call.
+                if node.callbacks and value != node.value:
                     for cb in node.callbacks:
                         cb()
                     node.callbacks.clear()
 
-                    if size_callback:
-                        cached_cache_len[0] -= size_callback(node.value)
-                        cached_cache_len[0] += size_callback(value)
+                # We don't bother to protect this by value != node.value as
+                # generally size_callback will be cheap compared with equality
+                # checks. (For example, taking the size of two dicts is quicker
+                # than comparing them for equality.)
+                if size_callback:
+                    cached_cache_len[0] -= size_callback(node.value)
+                    cached_cache_len[0] += size_callback(value)
 
                 node.callbacks.update(callbacks)
 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..7f79333e96 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+from twisted.internet import defer
 
 from synapse.util.async import ObservableDeferred
+from synapse.util.caches import metrics as cache_metrics
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
 
 
 class ResponseCache(object):
@@ -24,20 +31,68 @@ class ResponseCache(object):
     used rather than trying to compute a new response.
     """
 
-    def __init__(self, hs, timeout_ms=0):
+    def __init__(self, hs, name, timeout_ms=0):
         self.pending_result_cache = {}  # Requests that haven't finished yet.
 
         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.
 
+        self._name = name
+        self._metrics = cache_metrics.register_cache(
+            "response_cache",
+            size_callback=lambda: self.size(),
+            cache_name=name,
+        )
+
+    def size(self):
+        return len(self.pending_result_cache)
+
     def get(self, key):
+        """Look up the given key.
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if the request has completed, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        If there is no entry for the key, returns None. It is worth noting that
+        this means there is no way to distinguish a completed result of None
+        from an absent cache entry.
+
+        Args:
+            key (hashable):
+
+        Returns:
+            twisted.internet.defer.Deferred|None|E: None if there is no entry
+            for this key; otherwise either a deferred result or the result
+            itself.
+        """
         result = self.pending_result_cache.get(key)
         if result is not None:
+            self._metrics.inc_hits()
             return result.observe()
         else:
+            self._metrics.inc_misses()
             return None
 
     def set(self, key, deferred):
+        """Set the entry for the given key to the given deferred.
+
+        *deferred* should run its callbacks in the sentinel logcontext (ie,
+        you should wrap normal synapse deferreds with
+        logcontext.run_in_background).
+
+        Can return either a new Deferred (which also doesn't follow the synapse
+        logcontext rules), or, if *deferred* was already complete, the actual
+        result. You will probably want to make_deferred_yieldable the result.
+
+        Args:
+            key (hashable):
+            deferred (twisted.internet.defer.Deferred[T):
+
+        Returns:
+            twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+                result.
+        """
         result = ObservableDeferred(deferred, consumeErrors=True)
         self.pending_result_cache[key] = result
 
@@ -53,3 +108,52 @@ class ResponseCache(object):
 
         result.addBoth(remove)
         return result.observe()
+
+    def wrap(self, key, callback, *args, **kwargs):
+        """Wrap together a *get* and *set* call, taking care of logcontexts
+
+        First looks up the key in the cache, and if it is present makes it
+        follow the synapse logcontext rules and returns it.
+
+        Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+        follow the synapse logcontext rules, and adds the result to the cache.
+
+        Example usage:
+
+            @defer.inlineCallbacks
+            def handle_request(request):
+                # etc
+                defer.returnValue(result)
+
+            result = yield response_cache.wrap(
+                key,
+                handle_request,
+                request,
+            )
+
+        Args:
+            key (hashable): key to get/set in the cache
+
+            callback (callable): function to call if the key is not found in
+                the cache
+
+            *args: positional parameters to pass to the callback, if it is used
+
+            **kwargs: named paramters to pass to the callback, if it is used
+
+        Returns:
+            twisted.internet.defer.Deferred: yieldable result
+        """
+        result = self.get(key)
+        if not result:
+            logger.info("[%s]: no cached result for [%s], calculating new one",
+                        self._name, key)
+            d = run_in_background(callback, *args, **kwargs)
+            result = self.set(key, d)
+        elif not isinstance(result, defer.Deferred) or result.called:
+            logger.info("[%s]: using completed cached result for [%s]",
+                        self._name, key)
+        else:
+            logger.info("[%s]: using incomplete cached result for [%s]",
+                        self._name, key)
+        return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index b72bb0ff02..941d873ab8 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -13,20 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
 
 
 from blist import sorteddict
 import logging
-import os
 
 
 logger = logging.getLogger(__name__)
 
 
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
-
-
 class StreamChangeCache(object):
     """Keeps track of the stream positions of the latest change in a set of entities.
 
@@ -50,7 +46,7 @@ class StreamChangeCache(object):
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
         """
-        assert type(stream_pos) is int
+        assert type(stream_pos) is int or type(stream_pos) is long
 
         if stream_pos < self._earliest_known_stream_pos:
             self.metrics.inc_misses()
@@ -89,6 +85,21 @@ class StreamChangeCache(object):
 
         return result
 
+    def has_any_entity_changed(self, stream_pos):
+        """Returns if any entity has changed
+        """
+        assert type(stream_pos) is int
+
+        if stream_pos >= self._earliest_known_stream_pos:
+            self.metrics.inc_hits()
+            keys = self._cache.keys()
+            i = keys.bisect_right(stream_pos)
+
+            return i < len(keys)
+        else:
+            self.metrics.inc_misses()
+            return True
+
     def get_all_entities_changed(self, stream_pos):
         """Returns all entites that have had new things since the given
         position. If the position is too old it will return None.
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e68f94ce77..734331caaa 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -13,32 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+import logging
 
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_fn
-)
+from twisted.internet import defer
 
 from synapse.util import unwrapFirstError
-
-import logging
-
+from synapse.util.logcontext import PreserveLoggingContext
 
 logger = logging.getLogger(__name__)
 
 
 def user_left_room(distributor, user, room_id):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "user_left_room", user=user, room_id=room_id
-    )
+    with PreserveLoggingContext():
+        distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
 def user_joined_room(distributor, user, room_id):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "user_joined_room", user=user, room_id=room_id
-    )
+    with PreserveLoggingContext():
+        distributor.fire("user_joined_room", user=user, room_id=room_id)
 
 
 class Distributor(object):
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..3380970e4e
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+from six.moves import queue
+
+
+class BackgroundFileConsumer(object):
+    """A consumer that writes to a file like object. Supports both push
+    and pull producers
+
+    Args:
+        file_obj (file): The file like object to write to. Closed when
+            finished.
+    """
+
+    # For PushProducers pause if we have this many unwritten slices
+    _PAUSE_ON_QUEUE_SIZE = 5
+    # And resume once the size of the queue is less than this
+    _RESUME_ON_QUEUE_SIZE = 2
+
+    def __init__(self, file_obj):
+        self._file_obj = file_obj
+
+        # Producer we're registered with
+        self._producer = None
+
+        # True if PushProducer, false if PullProducer
+        self.streaming = False
+
+        # For PushProducers, indicates whether we've paused the producer and
+        # need to call resumeProducing before we get more data.
+        self._paused_producer = False
+
+        # Queue of slices of bytes to be written. When producer calls
+        # unregister a final None is sent.
+        self._bytes_queue = queue.Queue()
+
+        # Deferred that is resolved when finished writing
+        self._finished_deferred = None
+
+        # If the _writer thread throws an exception it gets stored here.
+        self._write_exception = None
+
+    def registerProducer(self, producer, streaming):
+        """Part of IConsumer interface
+
+        Args:
+            producer (IProducer)
+            streaming (bool): True if push based producer, False if pull
+                based.
+        """
+        if self._producer:
+            raise Exception("registerProducer called twice")
+
+        self._producer = producer
+        self.streaming = streaming
+        self._finished_deferred = run_in_background(
+            threads.deferToThread, self._writer
+        )
+        if not streaming:
+            self._producer.resumeProducing()
+
+    def unregisterProducer(self):
+        """Part of IProducer interface
+        """
+        self._producer = None
+        if not self._finished_deferred.called:
+            self._bytes_queue.put_nowait(None)
+
+    def write(self, bytes):
+        """Part of IProducer interface
+        """
+        if self._write_exception:
+            raise self._write_exception
+
+        if self._finished_deferred.called:
+            raise Exception("consumer has closed")
+
+        self._bytes_queue.put_nowait(bytes)
+
+        # If this is a PushProducer and the queue is getting behind
+        # then we pause the producer.
+        if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+            self._paused_producer = True
+            self._producer.pauseProducing()
+
+    def _writer(self):
+        """This is run in a background thread to write to the file.
+        """
+        try:
+            while self._producer or not self._bytes_queue.empty():
+                # If we've paused the producer check if we should resume the
+                # producer.
+                if self._producer and self._paused_producer:
+                    if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+                        reactor.callFromThread(self._resume_paused_producer)
+
+                bytes = self._bytes_queue.get()
+
+                # If we get a None (or empty list) then that's a signal used
+                # to indicate we should check if we should stop.
+                if bytes:
+                    self._file_obj.write(bytes)
+
+                # If its a pull producer then we need to explicitly ask for
+                # more stuff.
+                if not self.streaming and self._producer:
+                    reactor.callFromThread(self._producer.resumeProducing)
+        except Exception as e:
+            self._write_exception = e
+            raise
+        finally:
+            self._file_obj.close()
+
+    def wait(self):
+        """Returns a deferred that resolves when finished writing to file
+        """
+        return make_deferred_yieldable(self._finished_deferred)
+
+    def _resume_paused_producer(self):
+        """Gets called if we should resume producing after being paused
+        """
+        if self._paused_producer and self._producer:
+            self._paused_producer = False
+            self._producer.resumeProducing()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 6322f0f55c..f497b51f4a 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from frozendict import frozendict
+import simplejson as json
 
 
 def freeze(o):
@@ -49,3 +50,21 @@ def unfreeze(o):
         pass
 
     return o
+
+
+def _handle_frozendict(obj):
+    """Helper for EventEncoder. Makes frozendicts serializable by returning
+    the underlying dict
+    """
+    if type(obj) is frozendict:
+        # fishing the protected dict out of the object is a bit nasty,
+        # but we don't really want the overhead of copying the dict.
+        return obj._dict
+    raise TypeError('Object of type %s is not JSON serializable' %
+                    obj.__class__.__name__)
+
+
+# A JSONEncoder which is capable of encoding frozendics without barfing
+frozendict_json_encoder = json.JSONEncoder(
+    default=_handle_frozendict,
+)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 45be47159a..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
 
 import logging
 
@@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource):
     # extra resources to existing nodes. See self._resource_id for the key.
     resource_mappings = {}
     for full_path, res in desired_tree.items():
+        # twisted requires all resources to be bytes
+        full_path = full_path.encode("utf-8")
+
         logger.info("Attaching %s to path %s", res, full_path)
         last_resource = root_resource
-        for path_seg in full_path.split('/')[1:-1]:
+        for path_seg in full_path.split(b'/')[1:-1]:
             if path_seg not in last_resource.listNames():
                 # resource doesn't exist, so make a "dummy resource"
-                child_resource = Resource()
+                child_resource = NoResource()
                 last_resource.putChild(path_seg, child_resource)
                 res_id = _resource_id(last_resource, path_seg)
                 resource_mappings[res_id] = child_resource
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
 
         # ===========================
         # now attach the actual desired resource
-        last_path_seg = full_path.split('/')[-1]
+        last_path_seg = full_path.split(b'/')[-1]
 
         # if there is already a resource here, thieve its children and
         # replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 6c83eb213d..e086e12213 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -12,6 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
 from twisted.internet import defer
 
 import threading
@@ -32,7 +42,7 @@ try:
 
     def get_thread_resource_usage():
         return resource.getrusage(RUSAGE_THREAD)
-except:
+except Exception:
     # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
     # won't track resource usage by returning None.
     def get_thread_resource_usage():
@@ -42,13 +52,17 @@ except:
 class LoggingContext(object):
     """Additional context for log formatting. Contexts are scoped within a
     "with" block.
+
     Args:
         name (str): Name for the context for debugging.
     """
 
     __slots__ = [
-        "previous_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag", "alive",
+        "previous_context", "name", "ru_stime", "ru_utime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "usage_start", "usage_end",
+        "main_thread", "alive",
+        "request", "tag",
     ]
 
     thread_local = threading.local()
@@ -73,8 +87,12 @@ class LoggingContext(object):
         def add_database_transaction(self, duration_ms):
             pass
 
+        def add_database_scheduled(self, sched_ms):
+            pass
+
         def __nonzero__(self):
             return False
+        __bool__ = __nonzero__  # python3
 
     sentinel = Sentinel()
 
@@ -84,9 +102,17 @@ class LoggingContext(object):
         self.ru_stime = 0.
         self.ru_utime = 0.
         self.db_txn_count = 0
-        self.db_txn_duration = 0.
+
+        # ms spent waiting for db txns, excluding scheduling time
+        self.db_txn_duration_ms = 0
+
+        # ms spent waiting for db txns to be scheduled
+        self.db_sched_duration_ms = 0
+
         self.usage_start = None
+        self.usage_end = None
         self.main_thread = threading.current_thread()
+        self.request = None
         self.tag = ""
         self.alive = True
 
@@ -95,7 +121,11 @@ class LoggingContext(object):
 
     @classmethod
     def current_context(cls):
-        """Get the current logging context from thread local storage"""
+        """Get the current logging context from thread local storage
+
+        Returns:
+            LoggingContext: the current logging context
+        """
         return getattr(cls.thread_local, "current_context", cls.sentinel)
 
     @classmethod
@@ -145,11 +175,13 @@ class LoggingContext(object):
         self.alive = False
 
     def copy_to(self, record):
-        """Copy fields from this context to the record"""
-        for key, value in self.__dict__.items():
-            setattr(record, key, value)
+        """Copy logging fields from this context to a log record or
+        another LoggingContext
+        """
 
-        record.ru_utime, record.ru_stime = self.get_resource_usage()
+        # 'request' is the only field we currently use in the logger, so that's
+        # all we need to copy
+        record.request = self.request
 
     def start(self):
         if threading.current_thread() is not self.main_thread:
@@ -184,7 +216,16 @@ class LoggingContext(object):
 
     def add_database_transaction(self, duration_ms):
         self.db_txn_count += 1
-        self.db_txn_duration += duration_ms / 1000.
+        self.db_txn_duration_ms += duration_ms
+
+    def add_database_scheduled(self, sched_ms):
+        """Record a use of the database pool
+
+        Args:
+            sched_ms (int): number of milliseconds it took us to get a
+                connection
+        """
+        self.db_sched_duration_ms += sched_ms
 
 
 class LoggingContextFilter(logging.Filter):
@@ -251,80 +292,94 @@ class PreserveLoggingContext(object):
                 )
 
 
-class _PreservingContextDeferred(defer.Deferred):
-    """A deferred that ensures that all callbacks and errbacks are called with
-    the given logging context.
-    """
-    def __init__(self, context):
-        self._log_context = context
-        defer.Deferred.__init__(self)
-
-    def addCallbacks(self, callback, errback=None,
-                     callbackArgs=None, callbackKeywords=None,
-                     errbackArgs=None, errbackKeywords=None):
-        callback = self._wrap_callback(callback)
-        errback = self._wrap_callback(errback)
-        return defer.Deferred.addCallbacks(
-            self, callback,
-            errback=errback,
-            callbackArgs=callbackArgs,
-            callbackKeywords=callbackKeywords,
-            errbackArgs=errbackArgs,
-            errbackKeywords=errbackKeywords,
-        )
+def preserve_fn(f):
+    """Function decorator which wraps the function with run_in_background"""
+    def g(*args, **kwargs):
+        return run_in_background(f, *args, **kwargs)
+    return g
 
-    def _wrap_callback(self, f):
-        def g(res, *args, **kwargs):
-            with PreserveLoggingContext(self._log_context):
-                res = f(res, *args, **kwargs)
-            return res
-        return g
 
+def run_in_background(f, *args, **kwargs):
+    """Calls a function, ensuring that the current context is restored after
+    return from the function, and that the sentinel context is set once the
+    deferred returned by the function completes.
 
-def preserve_context_over_fn(fn, *args, **kwargs):
-    """Takes a function and invokes it with the given arguments, but removes
-    and restores the current logging context while doing so.
+    Useful for wrapping functions that return a deferred which you don't yield
+    on (for instance because you want to pass it to deferred.gatherResults()).
 
-    If the result is a deferred, call preserve_context_over_deferred before
-    returning it.
+    Note that if you completely discard the result, you should make sure that
+    `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+    CRITICAL error about an unhandled error will be logged without much
+    indication about where it came from.
     """
-    with PreserveLoggingContext():
-        res = fn(*args, **kwargs)
-
-    if isinstance(res, defer.Deferred):
-        return preserve_context_over_deferred(res)
-    else:
+    current = LoggingContext.current_context()
+    try:
+        res = f(*args, **kwargs)
+    except:   # noqa: E722
+        # the assumption here is that the caller doesn't want to be disturbed
+        # by synchronous exceptions, so let's turn them into Failures.
+        return defer.fail()
+
+    if not isinstance(res, defer.Deferred):
         return res
 
+    if res.called and not res.paused:
+        # The function should have maintained the logcontext, so we can
+        # optimise out the messing about
+        return res
 
-def preserve_context_over_deferred(deferred, context=None):
-    """Given a deferred wrap it such that any callbacks added later to it will
-    be invoked with the current context.
+    # The function may have reset the context before returning, so
+    # we need to restore it now.
+    ctx = LoggingContext.set_current_context(current)
+
+    # The original context will be restored when the deferred
+    # completes, but there is nothing waiting for it, so it will
+    # get leaked into the reactor or some other function which
+    # wasn't expecting it. We therefore need to reset the context
+    # here.
+    #
+    # (If this feels asymmetric, consider it this way: we are
+    # effectively forking a new thread of execution. We are
+    # probably currently within a ``with LoggingContext()`` block,
+    # which is supposed to have a single entry and exit point. But
+    # by spawning off another deferred, we are effectively
+    # adding a new exit point.)
+    res.addBoth(_set_context_cb, ctx)
+    return res
+
+
+def make_deferred_yieldable(deferred):
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed (or is not actually a Deferred), essentially
+    does nothing (just returns another completed deferred with the
+    result/failure).
+
+    If the deferred has not yet completed, resets the logcontext before
+    returning a deferred. Then, when the deferred completes, restores the
+    current logcontext before running callbacks/errbacks.
+
+    (This is more-or-less the opposite operation to run_in_background.)
     """
-    if context is None:
-        context = LoggingContext.current_context()
-    d = _PreservingContextDeferred(context)
-    deferred.chainDeferred(d)
-    return d
+    if not isinstance(deferred, defer.Deferred):
+        return deferred
 
+    if deferred.called and not deferred.paused:
+        # it looks like this deferred is ready to run any callbacks we give it
+        # immediately. We may as well optimise out the logcontext faffery.
+        return deferred
 
-def preserve_fn(f):
-    """Ensures that function is called with correct context and that context is
-    restored after return. Useful for wrapping functions that return a deferred
-    which you don't yield on.
-    """
-    current = LoggingContext.current_context()
+    # ok, we can't be sure that a yield won't block, so let's reset the
+    # logcontext, and add a callback to the deferred to restore it.
+    prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+    deferred.addBoth(_set_context_cb, prev_context)
+    return deferred
 
-    def g(*args, **kwargs):
-        with PreserveLoggingContext(current):
-            res = f(*args, **kwargs)
-            if isinstance(res, defer.Deferred):
-                return preserve_context_over_deferred(
-                    res, context=LoggingContext.sentinel
-                )
-            else:
-                return res
-    return g
+
+def _set_context_cb(result, context):
+    """A callback function which just sets the logging context"""
+    LoggingContext.set_current_context(context)
+    return result
 
 
 # modules to ignore in `logcontext_tracer`
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
new file mode 100644
index 0000000000..3e42868ea9
--- /dev/null
+++ b/synapse/util/logformatter.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from six import StringIO
+import logging
+import traceback
+
+
+class LogFormatter(logging.Formatter):
+    """Log formatter which gives more detail for exceptions
+
+    This is the same as the standard log formatter, except that when logging
+    exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+    sequence that led up to the point at which the exception was caught.
+    (Normally only stack frames between the point the exception was raised and
+    where it was caught are logged).
+    """
+    def __init__(self, *args, **kwargs):
+        super(LogFormatter, self).__init__(*args, **kwargs)
+
+    def formatException(self, ei):
+        sio = StringIO()
+        (typ, val, tb) = ei
+
+        # log the stack above the exception capture point if possible, but
+        # check that we actually have an f_back attribute to work around
+        # https://twistedmatrix.com/trac/ticket/9305
+
+        if tb and hasattr(tb.tb_frame, 'f_back'):
+            sio.write("Capture point (most recent call last):\n")
+            traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+        traceback.print_exception(typ, val, tb, None, sio)
+        s = sio.getvalue()
+        sio.close()
+        if s[-1:] == "\n":
+            s = s[:-1]
+        return s
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4ea930d3e8..e4b5687a4b 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-block_timer = metrics.register_distribution(
-    "block_timer",
-    labels=["block_name"]
+# total number of times we have hit this block
+block_counter = metrics.register_counter(
+    "block_count",
+    labels=["block_name"],
+    alternative_names=(
+        # the following are all deprecated aliases for the same metric
+        metrics.name_prefix + x for x in (
+            "_block_timer:count",
+            "_block_ru_utime:count",
+            "_block_ru_stime:count",
+            "_block_db_txn_count:count",
+            "_block_db_txn_duration:count",
+        )
+    )
+)
+
+block_timer = metrics.register_counter(
+    "block_time_seconds",
+    labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_timer:total",
+    ),
 )
 
-block_ru_utime = metrics.register_distribution(
-    "block_ru_utime", labels=["block_name"]
+block_ru_utime = metrics.register_counter(
+    "block_ru_utime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_utime:total",
+    ),
 )
 
-block_ru_stime = metrics.register_distribution(
-    "block_ru_stime", labels=["block_name"]
+block_ru_stime = metrics.register_counter(
+    "block_ru_stime_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_ru_stime:total",
+    ),
 )
 
-block_db_txn_count = metrics.register_distribution(
-    "block_db_txn_count", labels=["block_name"]
+block_db_txn_count = metrics.register_counter(
+    "block_db_txn_count", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_count:total",
+    ),
 )
 
-block_db_txn_duration = metrics.register_distribution(
-    "block_db_txn_duration", labels=["block_name"]
+# seconds spent waiting for db txns, excluding scheduling time, in this block
+block_db_txn_duration = metrics.register_counter(
+    "block_db_txn_duration_seconds", labels=["block_name"],
+    alternative_names=(
+        metrics.name_prefix + "_block_db_txn_duration:total",
+    ),
+)
+
+# seconds spent waiting for a db connection, in this block
+block_db_sched_duration = metrics.register_counter(
+    "block_db_sched_duration_seconds", labels=["block_name"],
 )
 
 
@@ -64,7 +101,9 @@ def measure_func(name):
 class Measure(object):
     __slots__ = [
         "clock", "name", "start_context", "start", "new_context", "ru_utime",
-        "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
+        "ru_stime",
+        "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+        "created_context",
     ]
 
     def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
 
         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
         self.db_txn_count = self.start_context.db_txn_count
-        self.db_txn_duration = self.start_context.db_txn_duration
+        self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
+        self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if isinstance(exc_type, Exception) or not self.start_context:
             return
 
         duration = self.clock.time_msec() - self.start
+
+        block_counter.inc(self.name)
         block_timer.inc_by(duration, self.name)
 
         context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
             context.db_txn_count - self.db_txn_count, self.name
         )
         block_db_txn_duration.inc_by(
-            context.db_txn_duration - self.db_txn_duration, self.name
+            (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
+            self.name
+        )
+        block_db_sched_duration.inc_by(
+            (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
+            self.name
         )
 
         if self.created_context:
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+    """ Loads a module with its config
+    Take a dict with keys 'module' (the module name) and 'config'
+    (the config dict).
+
+    Returns
+        Tuple of (provider class, parsed config object)
+    """
+    # We need to import the module, and then pick the class out of
+    # that, so we split based on the last dot.
+    module, clz = provider['module'].rsplit(".", 1)
+    module = importlib.import_module(module)
+    provider_class = getattr(module, clz)
+
+    try:
+        provider_config = provider_class.parse_config(provider["config"])
+    except Exception as e:
+        raise ConfigError(
+            "Failed to parse config for %r: %r" % (provider['module'], e)
+        )
+
+    return provider_class, provider_config
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
new file mode 100644
index 0000000000..607161e7f0
--- /dev/null
+++ b/synapse/util/msisdn.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import phonenumbers
+from synapse.api.errors import SynapseError
+
+
+def phone_number_to_msisdn(country, number):
+    """
+    Takes an ISO-3166-1 2 letter country code and phone number and
+    returns an msisdn representing the canonical version of that
+    phone number.
+    Args:
+        country (str): ISO-3166-1 2 letter country code
+        number (str): Phone number in a national or international format
+
+    Returns:
+        (str) The canonical form of the phone number, as an msisdn
+    Raises:
+            SynapseError if the number could not be parsed.
+    """
+    try:
+        phoneNumber = phonenumbers.parse(number, country)
+    except phonenumbers.NumberParseException:
+        raise SynapseError(400, "Unable to parse phone number")
+    return phonenumbers.format_number(
+        phoneNumber, phonenumbers.PhoneNumberFormat.E164
+    )[1:]
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..18424f6c36 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import LimitExceededError
 
 from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
 
 import collections
 import contextlib
@@ -150,7 +150,7 @@ class _PerHostRatelimiter(object):
                 "Ratelimit [%s]: sleeping req",
                 id(request_id),
             )
-            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+            ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
 
             self.sleeping_requests.add(request_id)
 
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 153ef001ad..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import synapse.util.logcontext
 from twisted.internet import defer
 
 from synapse.api.errors import CodeMessageException
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
 
 class NotRetryingDestination(Exception):
     def __init__(self, retry_last_ts, retry_interval, destination):
+        """Raised by the limiter (and federation client) to indicate that we are
+        are deliberately not attempting to contact a given server.
+
+        Args:
+            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+                to contact the server.  0 indicates that the last attempt was
+                successful or that we've never actually attempted to connect.
+            retry_interval (int): the time in milliseconds to wait until the next
+                attempt.
+            destination (str): the domain in question
+        """
+
         msg = "Not retrying server %s." % (destination,)
         super(NotRetryingDestination, self).__init__(msg)
 
@@ -35,7 +47,8 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False,
+                      **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -43,6 +56,14 @@ def get_retry_limiter(destination, clock, store, **kwargs):
     that will mark the destination as down if an exception is thrown (excluding
     CodeMessageException with code < 500)
 
+    Args:
+        destination (str): name of homeserver
+        clock (synapse.util.clock): timing source
+        store (synapse.storage.transactions.TransactionStore): datastore
+        ignore_backoff (bool): true to ignore the historical backoff data and
+            try the request anyway. We will still update the next
+            retry_interval on success/failure.
+
     Example usage:
 
         try:
@@ -66,7 +87,7 @@ def get_retry_limiter(destination, clock, store, **kwargs):
 
         now = int(clock.time_msec())
 
-        if retry_last_ts + retry_interval > now:
+        if not ignore_backoff and retry_last_ts + retry_interval > now:
             raise NotRetryingDestination(
                 retry_last_ts=retry_last_ts,
                 retry_interval=retry_interval,
@@ -124,7 +145,13 @@ class RetryDestinationLimiter(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         valid_err_code = False
-        if exc_type is not None and issubclass(exc_type, CodeMessageException):
+        if exc_type is None:
+            valid_err_code = True
+        elif not issubclass(exc_type, Exception):
+            # avoid treating exceptions which don't derive from Exception as
+            # failures; this is mostly so as not to catch defer._DefGen.
+            valid_err_code = True
+        elif issubclass(exc_type, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never received e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -142,11 +169,13 @@ class RetryDestinationLimiter(object):
             else:
                 valid_err_code = False
 
-        if exc_type is None or valid_err_code:
+        if valid_err_code:
             # We connected successfully.
             if not self.retry_interval:
                 return
 
+            logger.debug("Connection to %s was successful; clearing backoff",
+                         self.destination)
             retry_last_ts = 0
             self.retry_interval = 0
         else:
@@ -160,6 +189,10 @@ class RetryDestinationLimiter(object):
             else:
                 self.retry_interval = self.min_retry_interval
 
+            logger.debug(
+                "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
+                self.destination, exc_type, exc_val, self.retry_interval
+            )
             retry_last_ts = int(self.clock.time_msec())
 
         @defer.inlineCallbacks
@@ -168,9 +201,10 @@ class RetryDestinationLimiter(object):
                 yield self.store.set_destination_retry_timings(
                     self.destination, retry_last_ts, self.retry_interval
                 )
-            except:
+            except Exception:
                 logger.exception(
-                    "Failed to store set_destination_retry_timings",
+                    "Failed to store destination_retry_timings",
                 )
 
-        store_retry_timings()
+        # we deliberately do this in the background.
+        synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index a100f151d4..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
 
 import random
 import string
+from six.moves import range
 
 _string_with_symbols = (
     string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
 
 
 def random_string(length):
-    return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+    return ''.join(random.choice(string.ascii_letters) for _ in range(length))
 
 
 def random_string_with_symbols(length):
     return ''.join(
-        random.choice(_string_with_symbols) for _ in xrange(length)
+        random.choice(_string_with_symbols) for _ in range(length)
     )
 
 
@@ -40,3 +41,17 @@ def is_ascii(s):
         return False
     else:
         return True
+
+
+def to_ascii(s):
+    """Converts a string to ascii if it is ascii, otherwise leave it alone.
+
+    If given None then will return None.
+    """
+    if s is None:
+        return None
+
+    try:
+        return s.encode("ascii")
+    except UnicodeEncodeError:
+        return s
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+    """Checks whether a given format of 3PID is allowed to be used on this HS
+
+    Args:
+        hs (synapse.server.HomeServer): server
+        medium (str): 3pid medium - e.g. email, msisdn
+        address (str): address within that medium (e.g. "wotan@matrix.org")
+            msisdns need to first have been canonicalised
+    Returns:
+        bool: whether the 3PID medium/address is allowed to be added to this HS
+    """
+
+    if hs.config.allowed_local_3pids:
+        for constraint in hs.config.allowed_local_3pids:
+            logger.debug(
+                "Checking 3PID %s (%s) against %s (%s)",
+                address, medium, constraint['pattern'], constraint['medium'],
+            )
+            if (
+                medium == constraint['medium'] and
+                re.match(constraint['pattern'], address)
+            ):
+                return True
+    else:
+        return True
+
+    return False
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7412fc57a4..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from six.moves import range
+
 
 class _Entry(object):
     __slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
         # Add empty entries between the end of the current list and when we want
         # to insert. This ensures there are no gaps.
         self.entries.extend(
-            _Entry(key) for key in xrange(last_key, then_key + 1)
+            _Entry(key) for key in range(last_key, then_key + 1)
         )
 
         self.entries[-1].queue.append(obj)
@@ -91,7 +93,4 @@ class WheelTimer(object):
         return ret
 
     def __len__(self):
-        l = 0
-        for entry in self.entries:
-            l += len(entry.queue)
-        return l
+        return sum(len(entry.queue) for entry in self.entries)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 199b16d827..aaca2c584c 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import Membership, EventTypes
 
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 
 import logging
 
@@ -43,7 +43,8 @@ MEMBERSHIP_PRIORITY = (
 
 
 @defer.inlineCallbacks
-def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
+def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
+                              always_include_ids=frozenset()):
     """ Returns dict of user_id -> list of events that user is allowed to
     see.
 
@@ -54,9 +55,12 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
             * the user has not been a member of the room since the
             given events
         events ([synapse.events.EventBase]): list of events to filter
+        always_include_ids (set(event_id)): set of event ids to specifically
+            include (unless sender is ignored)
     """
-    forgotten = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.who_forgot_in_room)(
+    forgotten = yield make_deferred_yieldable(defer.gatherResults([
+        defer.maybeDeferred(
+            preserve_fn(store.who_forgot_in_room),
             room_id,
         )
         for room_id in frozenset(e.room_id for e in events)
@@ -90,6 +94,9 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
         if not event.is_state() and event.sender in ignore_list:
             return False
 
+        if event.event_id in always_include_ids:
+            return True
+
         state = event_id_to_state[event.event_id]
 
         # get the room_visibility at the time of the event.
@@ -134,6 +141,13 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
             if prev_membership not in MEMBERSHIP_PRIORITY:
                 prev_membership = "leave"
 
+            # Always allow the user to see their own leave events, otherwise
+            # they won't see the room disappear if they reject the invite
+            if membership == "leave" and (
+                prev_membership == "join" or prev_membership == "invite"
+            ):
+                return True
+
             new_priority = MEMBERSHIP_PRIORITY.index(membership)
             old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
             if old_priority < new_priority:
@@ -181,26 +195,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
 
 
 @defer.inlineCallbacks
-def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context):
-    user_ids = set(u[0] for u in user_tuples)
-    event_id_to_state = {}
-    for event_id, context in event_id_to_context.items():
-        state = yield store.get_events([
-            e_id
-            for key, e_id in context.current_state_ids.iteritems()
-            if key == (EventTypes.RoomHistoryVisibility, "")
-            or (key[0] == EventTypes.Member and key[1] in user_ids)
-        ])
-        event_id_to_state[event_id] = state
-
-    res = yield filter_events_for_clients(
-        store, user_tuples, events, event_id_to_state
-    )
-    defer.returnValue(res)
-
-
-@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False):
+def filter_events_for_client(store, user_id, events, is_peeking=False,
+                             always_include_ids=frozenset()):
     """
     Check which events a user is allowed to see
 
@@ -224,6 +220,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False):
         types=types
     )
     res = yield filter_events_for_clients(
-        store, [(user_id, is_peeking)], events, event_id_to_state
+        store, [(user_id, is_peeking)], events, event_id_to_state,
+        always_include_ids=always_include_ids,
     )
     defer.returnValue(res.get(user_id, []))