author    Erik Johnston <erik@matrix.org>  2017-01-30 14:36:46 +0000
committer Erik Johnston <erik@matrix.org>  2017-01-30 14:36:46 +0000
commit    717e4448c4b9159e002d835dc1250d5b5a19a1d2 (patch)
tree      a84439e6cb76fa537391c07ef61e8ddd97250a7a
parent    Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes (diff)
parent    Merge pull request #1857 from matrix-org/erikj/device_list_stream (diff)
download  synapse-717e4448c4b9159e002d835dc1250d5b5a19a1d2.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
-rw-r--r--  docs/admin_api/purge_remote_media.rst | 8
-rw-r--r--  synapse/app/federation_sender.py | 3
-rw-r--r--  synapse/app/synchrotron.py | 26
-rw-r--r--  synapse/federation/federation_client.py | 10
-rw-r--r--  synapse/federation/federation_server.py | 3
-rw-r--r--  synapse/federation/transaction_queue.py | 133
-rw-r--r--  synapse/federation/transport/client.py | 26
-rw-r--r--  synapse/federation/transport/server.py | 8
-rw-r--r--  synapse/handlers/_base.py | 8
-rw-r--r--  synapse/handlers/device.py | 141
-rw-r--r--  synapse/handlers/e2e_keys.py | 43
-rw-r--r--  synapse/handlers/federation.py | 1
-rw-r--r--  synapse/handlers/message.py | 6
-rw-r--r--  synapse/handlers/sync.py | 28
-rw-r--r--  synapse/replication/resource.py | 20
-rw-r--r--  synapse/replication/slave/storage/devices.py | 72
-rw-r--r--  synapse/replication/slave/storage/events.py | 10
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 6
-rw-r--r--  synapse/state.py | 3
-rw-r--r--  synapse/storage/__init__.py | 11
-rw-r--r--  synapse/storage/_base.py | 6
-rw-r--r--  synapse/storage/devices.py | 455
-rw-r--r--  synapse/storage/end_to_end_keys.py | 31
-rw-r--r--  synapse/storage/event_federation.py | 76
-rw-r--r--  synapse/storage/events.py | 383
-rw-r--r--  synapse/storage/schema/delta/40/device_list_streams.sql | 59
-rw-r--r--  synapse/storage/state.py | 52
-rw-r--r--  synapse/streams/events.py | 4
-rw-r--r--  synapse/types.py | 2
-rw-r--r--  tests/handlers/test_device.py | 18
-rw-r--r--  tests/handlers/test_directory.py | 1
-rw-r--r--  tests/handlers/test_profile.py | 1
-rw-r--r--  tests/handlers/test_typing.py | 3
-rw-r--r--  tests/replication/slave/storage/test_events.py | 72
-rw-r--r--  tests/rest/client/v1/test_rooms.py | 4
-rw-r--r--  tests/storage/test_appservice.py | 21
-rw-r--r--  tests/storage/test_end_to_end_keys.py | 17
37 files changed, 1340 insertions, 431 deletions
diff --git a/docs/admin_api/purge_remote_media.rst b/docs/admin_api/purge_remote_media.rst
index b26c6a9e7b..5deb02a3df 100644
--- a/docs/admin_api/purge_remote_media.rst
+++ b/docs/admin_api/purge_remote_media.rst
@@ -2,15 +2,13 @@ Purge Remote Media API
 ======================
 
 The purge remote media API allows server admins to purge old cached remote
-media. 
+media.
 
 The API is::
 
-    POST /_matrix/client/r0/admin/purge_media_cache
+    POST /_matrix/client/r0/admin/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
 
-    {
-        "before_ts": <unix_timestamp_in_ms>
-    }
+    {}
 
 Which will remove all cached media that was last accessed before
 ``<unix_timestamp_in_ms>``.
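
For illustration, a minimal sketch of calling the updated API with Python's
`requests` library; the homeserver base URL and access token below are assumed
placeholders, not part of this commit:

    # Minimal sketch of calling the purge API documented above. The homeserver
    # base URL and admin access token are assumed placeholders.
    import time
    import requests

    HOMESERVER = "http://localhost:8008"   # assumed homeserver base URL
    ACCESS_TOKEN = "<admin_access_token>"  # must belong to a server admin

    # Purge remote media last accessed more than 30 days ago.
    before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000

    resp = requests.post(
        HOMESERVER + "/_matrix/client/r0/admin/purge_media_cache",
        params={"before_ts": before_ts, "access_token": ACCESS_TOKEN},
        json={},  # the request body is now just an empty JSON object
    )
    resp.raise_for_status()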
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index ec06620efb..411e47d98d 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
 from synapse.util.async import sleep
@@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
 
 class FederationSenderSlaveStore(
     SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
-    SlavedRegistrationStore,
+    SlavedRegistrationStore, SlavedDeviceStore,
 ):
     pass
 
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 4dfc2dc648..b3fb408cfd 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
@@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
     SlavedFilteringStore,
     SlavedPresenceStore,
     SlavedDeviceInboxStore,
+    SlavedDeviceStore,
     RoomStore,
     BaseSlavedStore,
     ClientIpStore,  # After BaseSlavedStore because the constructor is different
@@ -380,6 +382,27 @@ class SynchrotronServer(HomeServer):
                         stream_key, position, users=users, rooms=rooms
                     )
 
+        @defer.inlineCallbacks
+        def notify_device_list_update(result):
+            stream = result.get("device_lists")
+            if not stream:
+                return
+
+            position_index = stream["field_names"].index("position")
+            user_index = stream["field_names"].index("user_id")
+
+            for row in stream["rows"]:
+                position = row[position_index]
+                user_id = row[user_index]
+
+                rooms = yield store.get_rooms_for_user(user_id)
+                room_ids = [r.room_id for r in rooms]
+
+                notifier.on_new_event(
+                    "device_list_key", position, rooms=room_ids,
+                )
+
+        @defer.inlineCallbacks
         def notify(result):
             stream = result.get("events")
             if stream:
@@ -417,6 +440,7 @@ class SynchrotronServer(HomeServer):
             notify_from_stream(
                 result, "to_device", "to_device_key", user="user_id"
             )
+            yield notify_device_list_update(result)
 
         while True:
             try:
@@ -427,7 +451,7 @@ class SynchrotronServer(HomeServer):
                 yield store.process_replication(result)
                 typing_handler.process_replication(result)
                 yield presence_handler.process_replication(result)
-                notify(result)
+                yield notify(result)
             except:
                 logger.exception("Error replicating from %r", replication_url)
                 yield sleep(5)
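
For reference, `notify_device_list_update` above consumes a replication result
whose "device_lists" stream has the shape sketched below. The column names
(position, user_id, destination) match what ReplicationResource.device_lists
writes later in this diff; the concrete values are illustrative only:

    # Illustrative "device_lists" entry in a replication result.
    result = {
        "device_lists": {
            "position": 42,
            "field_names": ["position", "user_id", "destination"],
            "rows": [
                [41, "@alice:example.com", None],            # no outbound poke
                [42, "@bob:example.com", "remote.example"],  # poke a remote server
            ],
        },
    }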
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9175bb33d..b5bcfd705a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -127,6 +127,16 @@ class FederationClient(FederationBase):
         )
 
     @log_function
+    def query_user_devices(self, destination, user_id, timeout=30000):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+        """
+        sent_queries_counter.inc("user_devices")
+        return self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
+
+    @log_function
     def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 862ccbef5d..e922b7ff4a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -416,6 +416,9 @@ class FederationServer(FederationBase):
     def on_query_client_keys(self, origin, content):
         return self.on_query_request("client_keys", content)
 
+    def on_query_user_devices(self, origin, user_id):
+        return self.on_query_request("user_devices", user_id)
+
     @defer.inlineCallbacks
     @log_function
     def on_claim_client_keys(self, origin, content):
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6b3a7abb9e..d18f6b6cfd 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -100,6 +100,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}
 
         self.last_device_stream_id_by_dest = {}
+        self.last_device_list_stream_id_by_dest = {}
 
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
@@ -305,62 +306,74 @@ class TransactionQueue(object):
             yield run_on_reactor()
 
             while True:
-                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
-                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
+                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                pending_presence = self.pending_presence_by_dest.pop(destination, {})
+                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-                    pending_edus.extend(
-                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
-                    )
+                pending_edus.extend(
+                    self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                )
 
-                    limiter = yield get_retry_limiter(
-                        destination,
-                        self.clock,
-                        self.store,
-                    )
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-                    device_message_edus, device_stream_id = (
-                        yield self._get_new_device_messages(destination)
-                    )
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
 
-                    pending_edus.extend(device_message_edus)
-                    if pending_presence:
-                        pending_edus.append(
-                            Edu(
-                                origin=self.server_name,
-                                destination=destination,
-                                edu_type="m.presence",
-                                content={
-                                    "push": [
-                                        format_user_presence_state(
-                                            presence, self.clock.time_msec()
-                                        )
-                                        for presence in pending_presence.values()
-                                    ]
-                                },
-                            )
+                pending_edus.extend(device_message_edus)
+                if pending_presence:
+                    pending_edus.append(
+                        Edu(
+                            origin=self.server_name,
+                            destination=destination,
+                            edu_type="m.presence",
+                            content={
+                                "push": [
+                                    format_user_presence_state(
+                                        presence, self.clock.time_msec()
+                                    )
+                                    for presence in pending_presence.values()
+                                ]
+                            },
                         )
+                    )
+
+                if pending_pdus:
+                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                 destination, len(pending_pdus))
 
-                    if pending_pdus:
-                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                     destination, len(pending_pdus))
+                if not pending_pdus and not pending_edus and not pending_failures:
+                    logger.debug("TX [%s] Nothing to send", destination)
+                    self.last_device_stream_id_by_dest[destination] = (
+                        device_stream_id
+                    )
+                    return
 
-                    if not pending_pdus and not pending_edus and not pending_failures:
-                        logger.debug("TX [%s] Nothing to send", destination)
-                        self.last_device_stream_id_by_dest[destination] = (
-                            device_stream_id
+                success = yield self._send_new_transaction(
+                    destination, pending_pdus, pending_edus, pending_failures,
+                    limiter=limiter,
+                )
+                if success:
+                    # Remove the acknowledged device messages from the database
+                    # Only bother if we actually sent some device messages
+                    if device_message_edus:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                        logger.info("Marking as sent %r %r", destination, dev_list_id)
+                        yield self.store.mark_as_sent_devices_by_remote(
+                            destination, dev_list_id
                         )
-                        return
 
-                    success = yield self._send_new_transaction(
-                        destination, pending_pdus, pending_edus, pending_failures,
-                        device_stream_id,
-                        should_delete_from_device_stream=bool(device_message_edus),
-                        limiter=limiter,
-                    )
-                    if not success:
-                        break
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
+                    self.last_device_list_stream_id_by_dest[destination] = dev_list_id
+                else:
+                    break
         except NotRetryingDestination:
             logger.debug(
                 "TX [%s] not ready for retry yet - "
@@ -387,13 +400,26 @@ class TransactionQueue(object):
             )
             for content in contents
         ]
-        defer.returnValue((edus, stream_id))
+
+        last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
+        now_stream_id, results = yield self.store.get_devices_by_remote(
+            destination, last_device_list
+        )
+        edus.extend(
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.device_list_update",
+                content=content,
+            )
+            for content in results
+        )
+        defer.returnValue((edus, stream_id, now_stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, device_stream_id,
-                              should_delete_from_device_stream, limiter):
+                              pending_failures, limiter):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -504,13 +530,6 @@ class TransactionQueue(object):
                         "Failed to send event %s to %s", p.event_id, destination
                     )
                 success = False
-            else:
-                # Remove the acknowledged device messages from the database
-                if should_delete_from_device_stream:
-                    yield self.store.delete_device_msgs_for_remote(
-                        destination, device_stream_id
-                    )
-                self.last_device_stream_id_by_dest[destination] = device_stream_id
         except RuntimeError as e:
             # We capture this here as there is nothing that actually listens
             # for this function's deferred to finish.
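
To make the new bookkeeping concrete: `store.get_devices_by_remote` (added in
synapse/storage/devices.py below) returns a `(now_stream_id, updates)` pair,
and each update dict is sent verbatim as the content of one
`m.device_list_update` EDU. A sketch of a return value, with made-up ids:

    # Illustrative return value from store.get_devices_by_remote. The first
    # update for a user carries as prev_id the last stream_id already sent to
    # this destination (empty if none); later updates for the same user chain
    # from the preceding row.
    now_stream_id = 7
    updates = [
        {
            "user_id": "@alice:example.com",
            "device_id": "DEV1",
            "prev_id": [],   # nothing previously sent for alice
            "stream_id": 6,
            # "keys" / "device_display_name" are added when the device has them
        },
        {
            "user_id": "@alice:example.com",
            "device_id": "DEV2",
            "prev_id": [6],  # chains from the previous update's stream_id
            "stream_id": 7,
        },
    ]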
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_user_devices(self, destination, user_id, timeout):
+        """Query the devices for a user id hosted on a remote server.
+
+        Response:
+            {
+              "stream_id": "...",
+              "devices": [ { ... } ]
+            }
+
+        Args:
+            destination(str): The server to query.
+            user_id(str): The user id of the user whose devices to fetch.
+            timeout(int): How long to wait for a response, in milliseconds.
+        Returns:
+            A dict containing the "stream_id" and "devices", as shown above.
+        """
+        path = PREFIX + "/user/devices/" + user_id
+
+        content = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            timeout=timeout,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
         return self.handler.on_query_client_keys(origin, content)
 
 
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    def on_GET(self, origin, content, query, user_id):
+        return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
     PATH = "/user/keys/claim"
 
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
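
Assuming the standard federation transport prefix (not shown in this hunk),
the new servlet serves requests of the following shape; the signed federation
Authorization header is elided, and the response shape follows
`on_federation_query_user_devices` in the device handler below:

    GET /_matrix/federation/v1/user/devices/@alice:example.com

    200 OK
    {
        "user_id": "@alice:example.com",
        "stream_id": 6,
        "devices": [
            {"device_id": "DEV1", "keys": {"...": "..."}}
        ]
    }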
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 90f96209f8..e83adc8339 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -88,9 +88,13 @@ class BaseHandler(object):
                     current_state = yield self.store.get_events(
                         context.current_state_ids.values()
                     )
-                    current_state = current_state.values()
                 else:
-                    current_state = yield self.store.get_current_state(event.room_id)
+                    current_state = yield self.state_handler.get_current_state(
+                        event.room_id
+                    )
+
+                current_state = current_state.values()
+
                 logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index aa68755936..6fefb85890 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,8 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
+from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
 
@@ -27,6 +29,21 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
+        self.state = hs.get_state_handler()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edu_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
+
+        hs.get_distributor().observe("user_left_room", self.user_left_room)
+
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
                                 initial_device_display_name=None):
@@ -45,29 +62,29 @@ class DeviceHandler(BaseHandler):
             str: device id (generated if none was supplied)
         """
         if device_id is not None:
-            yield self.store.store_device(
+            new_device = yield self.store.store_device(
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
-                ignore_if_known=True,
             )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
         # times in case of a clash.
         attempts = 0
         while attempts < 5:
-            try:
-                device_id = stringutils.random_string(10).upper()
-                yield self.store.store_device(
-                    user_id=user_id,
-                    device_id=device_id,
-                    initial_device_display_name=initial_device_display_name,
-                    ignore_if_known=False,
-                )
+            device_id = stringutils.random_string(10).upper()
+            new_device = yield self.store.store_device(
+                user_id=user_id,
+                device_id=device_id,
+                initial_device_display_name=initial_device_display_name,
+            )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
                 defer.returnValue(device_id)
-            except errors.StoreError:
-                attempts += 1
+            attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
 
@@ -147,6 +164,8 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
+        yield self.notify_device_update(user_id, [device_id])
+
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
         """ Update the given device
@@ -166,12 +185,110 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
+            yield self.notify_device_update(user_id, [device_id])
         except errors.StoreError, e:
             if e.code == 404:
                 raise errors.NotFoundError()
             else:
                 raise
 
+    @defer.inlineCallbacks
+    def notify_device_update(self, user_id, device_ids):
+        """Notify that a user's device(s) has changed. Pokes the notifier, and
+        remote servers if the user is local.
+        """
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        room_ids = [r.room_id for r in rooms]
+
+        hosts = set()
+        if self.hs.is_mine_id(user_id):
+            for room_id in room_ids:
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.discard(self.server_name)
+
+        position = yield self.store.add_device_change_to_streams(
+            user_id, device_ids, list(hosts)
+        )
+
+        yield self.notifier.on_new_event(
+            "device_list_key", position, rooms=room_ids,
+        )
+
+        if hosts:
+            logger.info("Sending device list update notif to: %r", hosts)
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    @defer.inlineCallbacks
+    def _incoming_device_list_update(self, origin, edu_content):
+        user_id = edu_content["user_id"]
+        device_id = edu_content["device_id"]
+        stream_id = edu_content["stream_id"]
+        prev_ids = edu_content.get("prev_id", [])
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            logger.warning("Got device list update edu for %r from %r", user_id, origin)
+            return
+
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        with (yield self._remote_edu_linearizer.queue(user_id)):
+            # If the prev id matches what's in our cache table, then we don't
+            # need to resync the user's device list; otherwise we do.
+            resync = True
+            if len(prev_ids) == 1:
+                extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+                    user_id
+                )
+                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+                if str(extremity) == str(prev_ids[0]):
+                    resync = False
+
+            if resync:
+                # Fetch all devices for the user.
+                result = yield self.federation.query_user_devices(origin, user_id)
+                stream_id = result["stream_id"]
+                devices = result["devices"]
+                yield self.store.update_remote_device_list_cache(
+                    user_id, devices, stream_id,
+                )
+                device_ids = [device["device_id"] for device in devices]
+                yield self.notify_device_update(user_id, device_ids)
+            else:
+                # Simply update the single device, since we know that is the only
+                # change (because of the single prev_id matching the current cache)
+                content = dict(edu_content)
+                for key in ("user_id", "device_id", "stream_id", "prev_id"):
+                    content.pop(key, None)
+                yield self.store.update_remote_device_list_cache_entry(
+                    user_id, device_id, content, stream_id,
+                )
+                yield self.notify_device_update(user_id, [device_id])
+
+    @defer.inlineCallbacks
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
+    @defer.inlineCallbacks
+    def user_left_room(self, user, room_id):
+        user_id = user.to_string()
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We no longer share rooms with this user, so we'll no longer
+            # receive device updates. Mark this in DB.
+            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
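
For reference, an illustrative `m.device_list_update` EDU content as consumed
by `_incoming_device_list_update` above; the field names come from the handler
and the values are made up:

    # The four bookkeeping fields are read (and stripped) by the handler; any
    # remaining keys, e.g. "keys" or "device_display_name", are stored as the
    # cached device content.
    edu_content = {
        "user_id": "@alice:remote.example",  # must match the sending server
        "device_id": "ABCDEFGHIJ",
        "stream_id": 6,    # sender's device-list stream position
        "prev_id": [5],    # stream_id(s) of the update(s) previously sent to us
    }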
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index b63a660c06..a16b9def8d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
             if self.is_mine_id(user_id):
                 local_query[user_id] = device_ids
             else:
-                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_ids
+                remote_queries[user_id] = device_ids
 
-        # do the queries
+        # First get local devices.
         failures = {}
         results = {}
         if local_query:
@@ -85,9 +84,42 @@ class E2eKeysHandler(object):
                 if user_id in local_query:
                     results[user_id] = keys
 
+        # Now attempt to get any remote devices from our local cache.
+        remote_queries_not_in_cache = {}
+        if remote_queries:
+            query_list = []
+            for user_id, device_ids in remote_queries.iteritems():
+                if device_ids:
+                    query_list.extend((user_id, device_id) for device_id in device_ids)
+                else:
+                    query_list.append((user_id, None))
+
+            user_ids_not_in_cache, remote_results = (
+                yield self.store.get_user_devices_from_cache(
+                    query_list
+                )
+            )
+            for user_id, devices in remote_results.iteritems():
+                user_devices = results.setdefault(user_id, {})
+                for device_id, device in devices.iteritems():
+                    keys = device.get("keys", None)
+                    device_display_name = device.get("device_display_name", None)
+                    if keys:
+                        result = dict(keys)
+                        unsigned = result.setdefault("unsigned", {})
+                        if device_display_name:
+                            unsigned["device_display_name"] = device_display_name
+                        user_devices[device_id] = result
+
+            for user_id in user_ids_not_in_cache:
+                domain = get_domain_from_id(user_id)
+                r = remote_queries_not_in_cache.setdefault(domain, {})
+                r[user_id] = remote_queries[user_id]
+
+        # Now fetch any devices that we don't have in our cache
         @defer.inlineCallbacks
         def do_remote_query(destination):
-            destination_query = remote_queries[destination]
+            destination_query = remote_queries_not_in_cache[destination]
             try:
                 limiter = yield get_retry_limiter(
                     destination, self.clock, self.store
@@ -119,7 +151,7 @@ class E2eKeysHandler(object):
 
         yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
-            for destination in remote_queries
+            for destination in remote_queries_not_in_cache
         ]))
 
         defer.returnValue({
@@ -259,6 +291,7 @@ class E2eKeysHandler(object):
                 user_id, device_id, time_now,
                 encode_canonical_json(device_keys)
             )
+            yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
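
A worked example of the cache-query construction above, with illustrative user
and device ids: the `remote_queries` map is flattened into (user_id, device_id)
pairs, where an empty device list means "all devices" and is encoded as a
single (user_id, None) entry:

    remote_queries = {
        "@alice:remote.example": ["DEV1", "DEV2"],  # specific devices
        "@bob:remote.example": [],                  # all of bob's devices
    }
    query_list = [
        ("@alice:remote.example", "DEV1"),
        ("@alice:remote.example", "DEV2"),
        ("@bob:remote.example", None),
    ]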
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d3f5892376..996bfd0e23 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            current_state=state,
         )
 
         defer.returnValue((event_stream_id, max_stream_id))
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 88bd2d572e..7a498af5a2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -208,8 +208,10 @@ class MessageHandler(BaseHandler):
                     content = builder.content
 
                     try:
-                        content["displayname"] = yield profile.get_displayname(target)
-                        content["avatar_url"] = yield profile.get_avatar_url(target)
+                        if "displayname" not in content:
+                            content["displayname"] = yield profile.get_displayname(target)
+                        if "avatar_url" not in content:
+                            content["avatar_url"] = yield profile.get_avatar_url(target)
                     except Exception as e:
                         logger.info(
                             "Failed to get profile information for %r: %s",
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c880f61685..9199f20817 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
     "to_device",  # List of direct messages for the device.
+    "device_lists",  # List of user_ids whose devices have chanegd
 ])):
     __slots__ = []
 
@@ -544,6 +545,10 @@ class SyncHandler(object):
 
         yield self._generate_sync_entry_for_to_device(sync_result_builder)
 
+        device_lists = yield self._generate_sync_entry_for_device_list(
+            sync_result_builder
+        )
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -551,10 +556,33 @@ class SyncHandler(object):
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
+            device_lists=device_lists,
             next_batch=sync_result_builder.now_token,
         ))
 
     @defer.inlineCallbacks
+    def _generate_sync_entry_for_device_list(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+
+        if since_token and since_token.device_list_key:
+            rooms = yield self.store.get_rooms_for_user(user_id)
+            room_ids = set(r.room_id for r in rooms)
+
+            user_ids_changed = set()
+            changed = yield self.store.get_user_whose_devices_changed(
+                since_token.device_list_key
+            )
+            for other_user_id in changed:
+                other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+                if room_ids.intersection(e.room_id for e in other_rooms):
+                    user_ids_changed.add(other_user_id)
+
+            defer.returnValue(user_ids_changed)
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the portion of the sync response. Populates
         `sync_result_builder` with the result.
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 4616e9b34a..a30e647474 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -46,6 +46,7 @@ STREAM_NAMES = (
     ("to_device",),
     ("public_rooms",),
     ("federation",),
+    ("device_lists",),
 )
 
 
@@ -140,6 +141,7 @@ class ReplicationResource(Resource):
         caches_token = self.store.get_cache_stream_token()
         public_rooms_token = self.store.get_current_public_room_stream_id()
         federation_token = self.federation_sender.get_current_token()
+        device_list_token = self.store.get_device_stream_token()
 
         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -155,6 +157,7 @@ class ReplicationResource(Resource):
             int(stream_token.to_device_key),
             int(public_rooms_token),
             int(federation_token),
+            int(device_list_token),
         ))
 
     @request_handler()
@@ -214,6 +217,7 @@ class ReplicationResource(Resource):
         yield self.caches(writer, current_token, limit, request_streams)
         yield self.to_device(writer, current_token, limit, request_streams)
         yield self.public_rooms(writer, current_token, limit, request_streams)
+        yield self.device_lists(writer, current_token, limit, request_streams)
         self.federation(writer, current_token, limit, request_streams, federation_ack)
         self.streams(writer, current_token, request_streams)
 
@@ -495,6 +499,20 @@ class ReplicationResource(Resource):
                 "position", "type", "content",
             ), position=upto_token)
 
+    @defer.inlineCallbacks
+    def device_lists(self, writer, current_token, limit, request_streams):
+        current_position = current_token.device_lists
+
+        device_lists = request_streams.get("device_lists")
+
+        if device_lists is not None and device_lists != current_position:
+            changes = yield self.store.get_all_device_list_changes_for_remotes(
+                device_lists,
+            )
+            writer.write_header_and_rows("device_lists", changes, (
+                "position", "user_id", "destination",
+            ), position=current_position)
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -527,7 +545,7 @@ class _Writer(object):
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
     "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
-    "federation",
+    "federation", "device_lists",
 ))):
     __slots__ = []
 
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
new file mode 100644
index 0000000000..ca46aa17b6
--- /dev/null
+++ b/synapse/replication/slave/storage/devices.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedDeviceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+
+        self.hs = hs
+
+        self._device_list_id_gen = SlavedIdTracker(
+            db_conn, "device_lists_stream", "stream_id",
+        )
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
+    get_device_stream_token = DataStore.get_device_stream_token.__func__
+    get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
+    get_devices_by_remote = DataStore.get_devices_by_remote.__func__
+    _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
+    _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
+    mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
+    _mark_as_sent_devices_by_remote_txn = (
+        DataStore._mark_as_sent_devices_by_remote_txn.__func__
+    )
+
+    def stream_positions(self):
+        result = super(SlavedDeviceStore, self).stream_positions()
+        result["device_lists"] = self._device_list_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("device_lists")
+        if stream:
+            self._device_list_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                stream_id = row[0]
+                user_id = row[1]
+                destination = row[2]
+
+                self._device_list_stream_cache.entity_has_changed(
+                    user_id, stream_id
+                )
+
+                if destination:
+                    self._device_list_federation_stream_cache.entity_has_changed(
+                        destination, stream_id
+                    )
+
+        return super(SlavedDeviceStore, self).process_replication(result)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 64f18bbb3e..b3f3bf7488 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
     get_latest_event_ids_in_room = EventFederationStore.__dict__[
         "get_latest_event_ids_in_room"
     ]
-    _get_current_state_for_key = StateStore.__dict__[
-        "_get_current_state_for_key"
-    ]
     get_invited_rooms_for_user = RoomMemberStore.__dict__[
         "get_invited_rooms_for_user"
     ]
@@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
     )
     get_event = DataStore.get_event.__func__
     get_events = DataStore.get_events.__func__
-    get_current_state = DataStore.get_current_state.__func__
-    get_current_state_for_key = DataStore.get_current_state_for_key.__func__
     get_rooms_for_user_where_membership_is = (
         DataStore.get_rooms_for_user_where_membership_is.__func__
     )
@@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore):
 
     def invalidate_caches_for_event(self, event, backfilled, reset_state):
         if reset_state:
-            self._get_current_state_for_key.invalidate_all()
             self.get_rooms_for_user.invalidate_all()
             self.get_users_in_room.invalidate((event.room_id,))
 
@@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore):
         if (not event.internal_metadata.is_invite_from_remote()
                 and event.internal_metadata.is_outlier()):
             return
-
-        self._get_current_state_for_key.invalidate((
-            event.room_id, event.type, event.state_key
-        ))
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 7199ec883a..b3d8001638 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
         )
 
         archived = self.encode_archived(
-            sync_result.archived, time_now, requester.access_token_id, filter.event_fields
+            sync_result.archived, time_now, requester.access_token_id,
+            filter.event_fields,
         )
 
         response_content = {
             "account_data": {"events": sync_result.account_data},
             "to_device": {"events": sync_result.to_device},
+            "device_lists": {
+                "changed": list(sync_result.device_lists),
+            },
             "presence": self.encode_presence(
                 sync_result.presence, time_now
             ),
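
With this change a /sync response gains a top-level section of the following
shape (illustrative value; other fields elided):

    "device_lists": {
        "changed": ["@alice:remote.example"]
    }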
diff --git a/synapse/state.py b/synapse/state.py
index 20aaacf40f..383d32b163 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory):
         dict[(str, str), synapse.events.FrozenEvent] is a map from
         (type, state_key) to event.
     """
+    if len(state_sets) == 1:
+        return state_sets[0]
+
     unconflicted_state, conflicted_state = _seperate(
         state_sets,
     )
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e8495f1eb9..b9968debe5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn, "device_lists_stream", "stream_id",
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=device_outbox_prefill,
         )
 
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 963ef999d5..05374682fd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -387,6 +387,10 @@ class SQLBaseStore(object):
         Args:
             table : string giving the table name
             values : dict of new column names and values for them
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
         """
         try:
             yield self.runInteraction(
@@ -398,6 +402,8 @@ class SQLBaseStore(object):
             # a cursor after we receive an error from the db.
             if not or_ignore:
                 raise
+            defer.returnValue(False)
+        defer.returnValue(True)
 
     @staticmethod
     def _simple_insert_txn(txn, table, values):
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 17920d4480..e68ee50152 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import ujson as json
 
 from twisted.internet import defer
 
@@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(DeviceStore, self).__init__(hs)
+
+        self._clock.looping_call(
+            self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+        )
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id,
-                     initial_device_display_name,
-                     ignore_if_known=True):
+                     initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
 
         Args:
             user_id (str): id of user associated with the device
             device_id (str): id of device
             initial_device_display_name (str): initial displayname of the
-               device
-            ignore_if_known (bool): ignore integrity errors which mean the
-               device is already known
+               device. Ignored if device exists.
         Returns:
-            defer.Deferred
-        Raises:
-            StoreError: if ignore_if_known is False and the device was already
-               known
+            defer.Deferred: boolean; True if the device was newly inserted,
+                False if a device with that ID already existed.
         """
         try:
-            yield self._simple_insert(
+            inserted = yield self._simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
                     "display_name": initial_device_display_name
                 },
                 desc="store_device",
-                or_ignore=ignore_if_known,
+                or_ignore=True,
             )
+            defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
                          " display_name=%s(%r) failed: %s",
@@ -139,3 +143,432 @@ class DeviceStore(SQLBaseStore):
         )
 
         defer.returnValue({d["device_id"]: d for d in devices})
+
+    def get_device_list_last_stream_id_for_remote(self, user_id):
+        """Get the last stream_id we got for a user. May be None if we haven't
+        got any information for them.
+        """
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+        return self._simple_delete(
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        """Updates a single user's device in the cache.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        """Replace the cache of the remote user's devices.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def get_devices_by_remote(self, destination, from_stream_id):
+        """Get stream of updates to send to remote servers
+
+        Returns:
+            (now_stream_id, [ { updates }, .. ])
+        """
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+            destination, int(from_stream_id)
+        )
+        if not has_changed:
+            return (now_stream_id, [])
+
+        return self.runInteraction(
+            "get_devices_by_remote", self._get_devices_by_remote_txn,
+            destination, from_stream_id, now_stream_id,
+        )
+
+    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+                                   now_stream_id):
+        sql = """
+            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            GROUP BY user_id, device_id
+        """
+        txn.execute(
+            sql, (destination, from_stream_id, now_stream_id, False)
+        )
+        rows = txn.fetchall()
+
+        if not rows:
+            return (now_stream_id, [])
+
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in rows}
+        devices = self._get_e2e_device_keys_txn(
+            txn, query_map.keys(), include_all_devices=True
+        )
+
+        prev_sent_id_sql = """
+            SELECT coalesce(max(stream_id), 0) as stream_id
+            FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id <= ?
+        """
+
+        results = []
+        for user_id, user_devices in devices.iteritems():
+            # The prev_id for the first row is always the last row before
+            # `from_stream_id`
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            prev_id = rows[0][0]
+            for device_id, device in user_devices.iteritems():
+                stream_id = query_map[(user_id, device_id)]
+                result = {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "prev_id": [prev_id] if prev_id else [],
+                    "stream_id": stream_id,
+                }
+
+                prev_id = stream_id
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+        return (now_stream_id, results)
+
+    def get_user_devices_from_cache(self, query_list):
+        """Get the devices (and keys if any) for remote users from the cache.
+
+        Args:
+            query_list(list): List of (user_id, device_id) tuples. If
+                device_id is falsey then all device ids for that user are
+                returned.
+
+        Returns:
+            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+            a set of user_ids and results_map is a mapping of
+            user_id -> device_id -> device_info
+        """
+        return self.runInteraction(
+            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+            query_list,
+        )
+
+    def _get_user_devices_from_cache_txn(self, txn, query_list):
+        user_ids = {user_id for user_id, _ in query_list}
+
+        user_ids_in_cache = set()
+        for user_id in user_ids:
+            stream_ids = self._simple_select_onecol_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={
+                    "user_id": user_id,
+                },
+                retcol="stream_id",
+            )
+            if stream_ids:
+                user_ids_in_cache.add(user_id)
+
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                content = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
+                    retcol="content",
+                )
+                results.setdefault(user_id, {})[device_id] = json.loads(content)
+            else:
+                devices = self._simple_select_list_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                    },
+                    retcols=("device_id", "content"),
+                )
+                results[user_id] = {
+                    device["device_id"]: json.loads(device["content"])
+                    for device in devices
+                }
+                user_ids_in_cache.discard(user_id)
+
+        return user_ids_not_in_cache, results
+
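A hypothetical caller (the helper names here are illustrative, not from this patch) would serve cache hits directly and only go to federation for the rest:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def query_remote_devices(store, query_list):
        not_in_cache, cached = yield store.get_user_devices_from_cache(
            query_list,
        )
        # Anything in `not_in_cache` would be re-fetched over federation
        # (e.g. via a /user/devices query) before being merged in.
        defer.returnValue(cached)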
+    def get_devices_with_keys_by_user(self, user_id):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
+    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+        """Mark that updates have successfully been sent to the destination.
+        """
+        return self.runInteraction(
+            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+            destination, stream_id,
+        )
+
+    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        sql = """
+            DELETE FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id < (
+                SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+                WHERE destination = ? AND stream_id <= ?
+            )
+        """
+        txn.execute(sql, (destination, destination, stream_id,))
+
+        sql = """
+            UPDATE device_lists_outbound_pokes SET sent = ?
+            WHERE destination = ? AND stream_id <= ?
+        """
+        txn.execute(sql, (True, destination, stream_id,))
+
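The two statements above keep exactly one poke per destination at or below `stream_id` (so that later prev_ids stay anchored) and mark everything up to that point as sent. A toy model of the effect, assuming unsent pokes at stream ids 3, 5 and 9 for a single destination:

    pokes = {3: False, 5: False, 9: False}  # stream_id -> sent
    mark_up_to = 7
    keep = max(s for s in pokes if s <= mark_up_to)  # == 5
    pokes = {
        s: (sent or s <= mark_up_to)
        for s, sent in pokes.items() if s >= keep
    }
    assert pokes == {5: True, 9: False}  # 9 remains to be retried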
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        """Get set of users whose devices have changed since `from_key`.
+        """
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row["user_id"] for row in rows))
+
+    def get_all_device_list_changes_for_remotes(self, from_key):
+        """Return a list of `(stream_id, user_id, destination)` which is the
+        combined list of changes to devices, and which destinations need to be
+        poked. `destination` may be None if no destinations need to be poked.
+        """
+        sql = """
+            SELECT stream_id, user_id, destination FROM device_lists_stream
+            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            WHERE stream_id > ?
+        """
+        return self._execute(
+            "get_users_and_hosts_device_list", None,
+            sql, from_key,
+        )
+
+    @defer.inlineCallbacks
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
+        """Persist that a user's devices have been updated, and which hosts
+        (if any) should be poked.
+        """
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_device_change_to_streams", self._add_device_change_txn,
+                user_id, device_ids, hosts, stream_id,
+            )
+        defer.returnValue(stream_id)
+
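For orientation, a device handler would drive this roughly as follows; this is a sketch with illustrative helper names (only add_device_change_to_streams and the "device_list_key" stream are from this changeset):

    @defer.inlineCallbacks
    def notify_device_update(self, user_id, device_ids):
        # Hypothetical: work out which remote hosts share a room with the
        # user, persist the change, then wake up local /sync streams.
        hosts = yield self.get_hosts_for_user(user_id)  # illustrative helper
        stream_id = yield self.store.add_device_change_to_streams(
            user_id, device_ids, hosts,
        )
        self.notifier.on_new_event(
            "device_list_key", stream_id, users=[user_id],
        )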
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
+        now = self._clock.time_msec()
+
+        txn.call_after(
+            self._device_list_stream_cache.entity_has_changed,
+            user_id, stream_id,
+        )
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host, stream_id,
+            )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_outbound_pokes",
+            values=[
+                {
+                    "destination": destination,
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "sent": False,
+                    "ts": now,
+                }
+                for destination in hosts
+                for device_id in device_ids
+            ]
+        )
+
+    def get_device_stream_token(self):
+        return self._device_list_id_gen.get_current_token()
+
+    def _prune_old_outbound_device_pokes(self):
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers. We keep one entry per
+        (destination, user_id) tuple to ensure that the prev_ids remain correct
+        if the server does come back.
+        """
+        now = self._clock.time_msec()
+
+        def _prune_txn(txn):
+            select_sql = """
+                SELECT destination, user_id, max(stream_id) as stream_id
+                FROM device_lists_outbound_pokes
+                GROUP BY destination, user_id
+            """
+
+            txn.execute(select_sql)
+            rows = txn.fetchall()
+
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+            """
+
+            txn.executemany(
+                delete_sql,
+                (
+                    (now, row["destination"], row["user_id"], row["stream_id"])
+                    for row in rows
+                )
+            )
+
+        return self.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn
+        )
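A minimal sketch of scheduling the prune job, assuming the store has the usual Synapse Clock to hand (the hourly interval is illustrative):

    self._clock.looping_call(
        self._prune_old_outbound_device_pokes, 60 * 60 * 1000,
    )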
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 385d607056..85763f7ceb 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,9 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
-
-import twisted.internet.defer
+from twisted.internet import defer
 
 from ._base import SQLBaseStore
 
@@ -33,10 +31,12 @@ class EndToEndKeyStore(SQLBaseStore):
             }
         )
 
-    def get_e2e_device_keys(self, query_list):
+    def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
+            include_all_devices (bool): whether to include entries for devices
+                that don't have device keys
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -45,41 +45,42 @@ class EndToEndKeyStore(SQLBaseStore):
             return {}
 
         return self.runInteraction(
-            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn,
+            query_list, include_all_devices,
         )
 
-    def _get_e2e_device_keys_txn(self, txn, query_list):
+    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
 
         for (user_id, device_id) in query_list:
-            query_clause = "k.user_id = ?"
+            query_clause = "user_id = ?"
             query_params.append(user_id)
 
             if device_id:
-                query_clause += " AND k.device_id = ?"
+                query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
             query_clauses.append(query_clause)
 
         sql = (
-            "SELECT k.user_id, k.device_id, "
+            "SELECT user_id, device_id, "
             "    d.display_name AS device_display_name, "
             "    k.key_json"
-            " FROM e2e_device_keys_json k"
-            "    LEFT JOIN devices d ON d.user_id = k.user_id"
-            "      AND d.device_id = k.device_id"
+            " FROM devices d"
+            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
             " WHERE %s"
         ) % (
+            "LEFT" if include_all_devices else "INNER",
             " OR ".join("(" + q + ")" for q in query_clauses)
         )
 
         txn.execute(sql, query_params)
         rows = self.cursor_to_dict(txn)
 
-        result = collections.defaultdict(dict)
+        result = {}
         for row in rows:
-            result[row["user_id"]][row["device_id"]] = row
+            result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
         return result
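As a usage sketch (user and device IDs are illustrative): passing (user_id, None) selects every device for that user, and include_all_devices=True also lists devices that have not uploaded keys yet:

    result = yield store.get_e2e_device_keys(
        [("@alice:example.com", None), ("@bob:example.com", "BOBPHONE")],
        include_all_devices=True,
    )
    # result is a nested dict, e.g.:
    # {
    #     "@alice:example.com": {"DEV1": {...}, "DEV2": {...}},
    #     "@bob:example.com": {"BOBPHONE": {...}},
    # }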
 
@@ -152,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @twisted.internet.defer.inlineCallbacks
+    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
         yield self._simple_delete(
             table="e2e_device_keys_json",
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 53feaa1960..f0aa2193fb 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
             ],
         )
 
-        self._update_extremeties(txn, events)
+        self._update_backward_extremeties(txn, events)
 
-    def _update_extremeties(self, txn, events):
-        """Updates the event_*_extremities tables based on the new/updated
+    def _update_backward_extremeties(self, txn, events):
+        """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.
 
         This is called for new events *and* for events that were outliers, but
-        are are now being persisted as non-outliers.
+        are now being persisted as non-outliers.
+
+        Forward extremities are handled when we first start persisting the events.
         """
         events_by_room = {}
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
-        for room_id, room_events in events_by_room.items():
-            prevs = [
-                e_id for ev in room_events for e_id, _ in ev.prev_events
-                if not ev.internal_metadata.is_outlier()
-            ]
-            if prevs:
-                txn.execute(
-                    "DELETE FROM event_forward_extremities"
-                    " WHERE room_id = ?"
-                    " AND event_id in (%s)" % (
-                        ",".join(["?"] * len(prevs)),
-                    ),
-                    [room_id] + prevs,
-                )
-
-        query = (
-            "INSERT INTO event_forward_extremities (event_id, room_id)"
-            " SELECT ?, ? WHERE NOT EXISTS ("
-            " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
-            " )"
-        )
-
-        txn.executemany(
-            query,
-            [
-                (ev.event_id, ev.room_id, ev.event_id) for ev in events
-                if not ev.internal_metadata.is_outlier()
-            ]
-        )
-
-        # We now insert into stream_ordering_to_exterm a mapping from room_id,
-        # new stream_ordering to new forward extremeties in the room.
-        # This allows us to later efficiently look up the forward extremeties
-        # for a room before a given stream_ordering
-        max_stream_ord = max(
-            ev.internal_metadata.stream_ordering for ev in events
-        )
-        new_extrem = {}
-        for room_id in events_by_room:
-            event_ids = self._simple_select_onecol_txn(
-                txn,
-                table="event_forward_extremities",
-                keyvalues={"room_id": room_id},
-                retcol="event_id",
-            )
-            new_extrem[room_id] = event_ids
-
-        self._simple_insert_many_txn(
-            txn,
-            table="stream_ordering_to_exterm",
-            values=[
-                {
-                    "room_id": room_id,
-                    "event_id": event_id,
-                    "stream_ordering": max_stream_ord,
-                }
-                for room_id, extrem_evs in new_extrem.items()
-                for event_id in extrem_evs
-            ]
-        )
-
         query = (
             "INSERT INTO event_backward_extremities (event_id, room_id)"
             " SELECT ?, ? WHERE NOT EXISTS ("
@@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
             ]
         )
 
-        for room_id in events_by_room:
-            txn.call_after(
-                self.get_latest_event_ids_in_room.invalidate, (room_id,)
-            )
-
     def get_forward_extremeties_for_room(self, room_id, stream_ordering):
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 0bb6420b4f..910a37ae61 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import SQLBaseStore, _RollbackButIsFineException
+from ._base import SQLBaseStore
 
 from twisted.internet import defer, reactor
 
@@ -27,6 +27,7 @@ from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
+from synapse.state import resolve_events
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
@@ -71,22 +72,19 @@ class _EventPeristenceQueue(object):
     """
 
     _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
-        "events_and_contexts", "current_state", "backfilled", "deferred",
+        "events_and_contexts", "backfilled", "deferred",
     ))
 
     def __init__(self):
         self._event_persist_queues = {}
         self._currently_persisting_rooms = set()
 
-    def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
+    def add_to_queue(self, room_id, events_and_contexts, backfilled):
         """Add events to the queue, with the given persist_event options.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
             end_item = queue[-1]
-            if end_item.current_state or current_state:
-                # We perist events with current_state set to True one at a time
-                pass
             if end_item.backfilled == backfilled:
                 end_item.events_and_contexts.extend(events_and_contexts)
                 return end_item.deferred.observe()
@@ -96,7 +94,6 @@ class _EventPeristenceQueue(object):
         queue.append(self._EventPersistQueueItem(
             events_and_contexts=events_and_contexts,
             backfilled=backfilled,
-            current_state=current_state,
             deferred=deferred,
         ))
 
@@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore):
             d = preserve_fn(self._event_persist_queue.add_to_queue)(
                 room_id, evs_ctxs,
                 backfilled=backfilled,
-                current_state=None,
             )
             deferreds.append(d)
 
@@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, current_state=None, backfilled=False):
+    def persist_event(self, event, context, backfilled=False):
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)],
             backfilled=backfilled,
-            current_state=current_state,
         )
 
         self._maybe_start_persisting(event.room_id)
@@ -246,21 +241,10 @@ class EventsStore(SQLBaseStore):
     def _maybe_start_persisting(self, room_id):
         @defer.inlineCallbacks
         def persisting_queue(item):
-            if item.current_state:
-                for event, context in item.events_and_contexts:
-                    # There should only ever be one item in
-                    # events_and_contexts when current_state is
-                    # not None
-                    yield self._persist_event(
-                        event, context,
-                        current_state=item.current_state,
-                        backfilled=item.backfilled,
-                    )
-            else:
-                yield self._persist_events(
-                    item.events_and_contexts,
-                    backfilled=item.backfilled,
-                )
+            yield self._persist_events(
+                item.events_and_contexts,
+                backfilled=item.backfilled,
+            )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
@@ -294,35 +278,183 @@ class EventsStore(SQLBaseStore):
             for chunk in chunks:
                 # We can't easily parallelize these since different chunks
                 # might contain the same event. :(
+
+                # NB: Assumes that we are only persisting events for one room
+                # at a time.
+                new_forward_extremeties = {}
+                current_state_for_room = {}
+                if not backfilled:
+                    with Measure(self._clock, "_calculate_state_and_extrem"):
+                        # Work out the new "current state" for each room.
+                        # We do this by working out what the new extremities are and then
+                        # calculating the state from that.
+                        events_by_room = {}
+                        for event, context in chunk:
+                            events_by_room.setdefault(event.room_id, []).append(
+                                (event, context)
+                            )
+
+                        for room_id, ev_ctx_rm in events_by_room.items():
+                            # Work out new extremities by recursively adding and removing
+                            # the new events.
+                            latest_event_ids = yield self.get_latest_event_ids_in_room(
+                                room_id
+                            )
+                            new_latest_event_ids = yield self._calculate_new_extremeties(
+                                room_id, [ev for ev, _ in ev_ctx_rm]
+                            )
+
+                            if new_latest_event_ids == set(latest_event_ids):
+                                # No change in extremities, so no change in state
+                                continue
+
+                            new_forward_extremeties[room_id] = new_latest_event_ids
+
+                            state = yield self._calculate_state_delta(
+                                room_id, ev_ctx_rm, new_latest_event_ids
+                            )
+                            if state:
+                                current_state_for_room[room_id] = state
+
                 yield self.runInteraction(
                     "persist_events",
                     self._persist_events_txn,
                     events_and_contexts=chunk,
                     backfilled=backfilled,
                     delete_existing=delete_existing,
+                    current_state_for_room=current_state_for_room,
+                    new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))
 
-    @_retry_on_integrity_error
     @defer.inlineCallbacks
-    @log_function
-    def _persist_event(self, event, context, current_state=None, backfilled=False,
-                       delete_existing=False):
-        try:
-            with self._stream_id_gen.get_next() as stream_ordering:
-                event.internal_metadata.stream_ordering = stream_ordering
-                yield self.runInteraction(
-                    "persist_event",
-                    self._persist_event_txn,
-                    event=event,
-                    context=context,
-                    current_state=current_state,
-                    backfilled=backfilled,
-                    delete_existing=delete_existing,
-                )
-                persist_event_counter.inc()
-        except _RollbackButIsFineException:
-            pass
+    def _calculate_new_extremeties(self, room_id, events):
+        """Calculates the new forward extremeties for a room given events to
+        persist.
+
+        Assumes that we are only persisting events for one room at a time.
+        """
+        latest_event_ids = yield self.get_latest_event_ids_in_room(
+            room_id
+        )
+        new_latest_event_ids = set(latest_event_ids)
+        # First, add all the new events to the list
+        new_latest_event_ids.update(
+            event.event_id for event in events
+            if not event.internal_metadata.is_outlier()
+        )
+        # Now remove all events that are referenced by the to-be-added events
+        new_latest_event_ids.difference_update(
+            e_id
+            for event in events
+            for e_id, _ in event.prev_events
+            if not event.internal_metadata.is_outlier()
+        )
+
+        # And finally remove any events that are referenced by previously added
+        # events.
+        rows = yield self._simple_select_many_batch(
+            table="event_edges",
+            column="prev_event_id",
+            iterable=list(new_latest_event_ids),
+            retcols=["prev_event_id"],
+            keyvalues={
+                "room_id": room_id,
+                "is_state": False,
+            },
+            desc="_calculate_new_extremeties",
+        )
+
+        new_latest_event_ids.difference_update(
+            row["prev_event_id"] for row in rows
+        )
+
+        defer.returnValue(new_latest_event_ids)
+
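A worked example of the set arithmetic above with hypothetical event IDs (the final database check for events referenced by earlier extremities is omitted): if the room's extremities are {A} and we persist B (pointing at A) and C (pointing at B), only C survives:

    latest = {"$A"}
    new_events = [("$B", ["$A"]), ("$C", ["$B"])]  # (event_id, prev_events)

    new_latest = set(latest)
    new_latest.update(ev_id for ev_id, _ in new_events)
    for _, prevs in new_events:
        new_latest.difference_update(prevs)
    assert new_latest == {"$C"}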
+    @defer.inlineCallbacks
+    def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
+        """Calculate the new state deltas for a room.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entries to be
+            deleted from current_state_events first; `to_insert` are the
+            entries to insert.
+            May return None if there are no changes to be applied.
+        """
+        # Now we need to work out the different state sets for
+        # each of the new forward extremities
+        state_sets = []
+        missing_event_ids = []
+        was_updated = False
+        for event_id in new_latest_event_ids:
+            # First search in the list of new events we're adding,
+            # and then use the current state from that
+            for ev, ctx in events_context:
+                if event_id == ev.event_id:
+                    if ctx.current_state_ids is None:
+                        raise Exception("Unknown current state")
+                    state_sets.append(ctx.current_state_ids)
+                    if ctx.delta_ids or hasattr(ev, "state_key"):
+                        was_updated = True
+                    break
+            else:
+                # If we couldn't find it, then we'll need to pull
+                # the state from the database
+                was_updated = True
+                missing_event_ids.append(event_id)
+
+        if missing_event_ids:
+            # Now pull out the state for any missing events from DB
+            event_to_groups = yield self._get_state_group_for_events(
+                missing_event_ids,
+            )
+
+            groups = set(event_to_groups.values())
+            group_to_state = yield self._get_state_for_groups(groups)
+
+            state_sets.extend(group_to_state.values())
+
+        if not new_latest_event_ids:
+            current_state = {}
+        elif was_updated:
+            current_state = yield resolve_events(
+                state_sets,
+                state_map_factory=lambda ev_ids: self.get_events(
+                    ev_ids, get_prev_content=False, check_redacted=False,
+                ),
+            )
+        else:
+            return
+
+        existing_state_rows = yield self._simple_select_list(
+            table="current_state_events",
+            keyvalues={"room_id": room_id},
+            retcols=["event_id", "type", "state_key"],
+            desc="_calculate_state_delta",
+        )
+
+        existing_events = set(row["event_id"] for row in existing_state_rows)
+        new_events = set(current_state.itervalues())
+        changed_events = existing_events ^ new_events
+
+        if not changed_events:
+            return
+
+        to_delete = {
+            (row["type"], row["state_key"]): row["event_id"]
+            for row in existing_state_rows
+            if row["event_id"] in changed_events
+        }
+        events_to_insert = (new_events - existing_events)
+        to_insert = {
+            key: ev_id for key, ev_id in current_state.iteritems()
+            if ev_id in events_to_insert
+        }
+
+        defer.returnValue((to_delete, to_insert))
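The returned values are plain state dicts; for example (hypothetical event IDs), if @bob's join replaces his earlier leave in the current state, the delta is:

    to_delete = {("m.room.member", "@bob:example.com"): "$old_leave_event"}
    to_insert = {("m.room.member", "@bob:example.com"): "$new_join_event"}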
 
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -381,52 +513,9 @@ class EventsStore(SQLBaseStore):
         defer.returnValue({e.event_id: e for e in events})
 
     @log_function
-    def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
-                           delete_existing=False):
-        # We purposefully do this first since if we include a `current_state`
-        # key, we *want* to update the `current_state_events` table
-        if current_state:
-            txn.call_after(self._get_current_state_for_key.invalidate_all)
-            txn.call_after(self.get_rooms_for_user.invalidate_all)
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-
-            # Add an entry to the current_state_resets table to record the point
-            # where we clobbered the current state
-            stream_order = event.internal_metadata.stream_ordering
-            self._simple_insert_txn(
-                txn,
-                table="current_state_resets",
-                values={"event_stream_ordering": stream_order}
-            )
-
-            self._simple_delete_txn(
-                txn,
-                table="current_state_events",
-                keyvalues={"room_id": event.room_id},
-            )
-
-            for s in current_state:
-                self._simple_insert_txn(
-                    txn,
-                    "current_state_events",
-                    {
-                        "event_id": s.event_id,
-                        "room_id": s.room_id,
-                        "type": s.type,
-                        "state_key": s.state_key,
-                    }
-                )
-
-        return self._persist_events_txn(
-            txn,
-            [(event, context)],
-            backfilled=backfilled,
-            delete_existing=delete_existing,
-        )
-
-    @log_function
     def _persist_events_txn(self, txn, events_and_contexts, backfilled,
-                            delete_existing=False):
+                            delete_existing=False, current_state_for_room={},
+                            new_forward_extremeties={}):
         """Insert some number of room events into the necessary database tables.
 
         Rejected events are only inserted into the events table, the events_json table,
@@ -436,6 +525,97 @@ class EventsStore(SQLBaseStore):
         If delete_existing is True then existing events will be purged from the
         database before insertion. This is useful when retrying due to IntegrityError.
         """
+        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
+        for room_id, current_state_tuple in current_state_for_room.iteritems():
+            to_delete, to_insert = current_state_tuple
+            txn.executemany(
+                "DELETE FROM current_state_events WHERE event_id = ?",
+                [(ev_id,) for ev_id in to_delete.itervalues()],
+            )
+
+            self._simple_insert_many_txn(
+                txn,
+                table="current_state_events",
+                values=[
+                    {
+                        "event_id": ev_id,
+                        "room_id": room_id,
+                        "type": key[0],
+                        "state_key": key[1],
+                    }
+                    for key, ev_id in to_insert.iteritems()
+                ],
+            )
+
+            # Invalidate the various caches
+
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = set(
+                state_key for ev_type, state_key in to_delete.iterkeys()
+                if ev_type == EventTypes.Member
+            )
+            members_changed.update(
+                state_key for ev_type, state_key in to_insert.iterkeys()
+                if ev_type == EventTypes.Member
+            )
+
+            for member in members_changed:
+                txn.call_after(self.get_rooms_for_user.invalidate, (member,))
+
+            txn.call_after(self.get_users_in_room.invalidate, (room_id,))
+
+            # Add an entry to the current_state_resets table to record the point
+            # where we clobbered the current state
+            self._simple_insert_txn(
+                txn,
+                table="current_state_resets",
+                values={"event_stream_ordering": max_stream_order}
+            )
+
+        for room_id, new_extrem in new_forward_extremeties.items():
+            self._simple_delete_txn(
+                txn,
+                table="event_forward_extremities",
+                keyvalues={"room_id": room_id},
+            )
+            txn.call_after(
+                self.get_latest_event_ids_in_room.invalidate, (room_id,)
+            )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="event_forward_extremities",
+            values=[
+                {
+                    "event_id": ev_id,
+                    "room_id": room_id,
+                }
+                for room_id, new_extrem in new_forward_extremeties.items()
+                for ev_id in new_extrem
+            ],
+        )
+        # We now insert into stream_ordering_to_exterm a mapping from room_id,
+        # new stream_ordering to new forward extremities in the room.
+        # This allows us to later efficiently look up the forward extremities
+        # for a room before a given stream_ordering
+        self._simple_insert_many_txn(
+            txn,
+            table="stream_ordering_to_exterm",
+            values=[
+                {
+                    "room_id": room_id,
+                    "event_id": event_id,
+                    "stream_ordering": max_stream_order,
+                }
+                for room_id, new_extrem in new_forward_extremeties.items()
+                for event_id in new_extrem
+            ]
+        )
+
         # Ensure that we don't have the same event twice.
         # Pick the earliest non-outlier if there is one, else the earliest one.
         new_events_and_contexts = OrderedDict()
@@ -550,7 +730,7 @@ class EventsStore(SQLBaseStore):
 
                 # Update the event_backward_extremities table now that this
                 # event isn't an outlier any more.
-                self._update_extremeties(txn, [event])
+                self._update_backward_extremeties(txn, [event])
 
         events_and_contexts = [
             ec for ec in events_and_contexts if ec[0] not in to_remove
@@ -804,29 +984,6 @@ class EventsStore(SQLBaseStore):
             # to update the current state table
             return
 
-        for event, _ in state_events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                # Outlier events shouldn't clobber the current state.
-                continue
-
-            txn.call_after(
-                self._get_current_state_for_key.invalidate,
-                (event.room_id, event.type, event.state_key,)
-            )
-
-            self._simple_upsert_txn(
-                txn,
-                "current_state_events",
-                keyvalues={
-                    "room_id": event.room_id,
-                    "type": event.type,
-                    "state_key": event.state_key,
-                },
-                values={
-                    "event_id": event.event_id,
-                }
-            )
-
         return
 
     def _add_to_cache(self, txn, events_and_contexts):
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
new file mode 100644
index 0000000000..54841b3843
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -0,0 +1,59 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Cache of remote devices.
+CREATE TABLE device_lists_remote_cache (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
+
+
+-- The last update we got for a user. Empty if we're not receiving updates for
+-- that user.
+CREATE TABLE device_lists_remote_extremeties (
+    user_id TEXT NOT NULL,
+    stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
+-- Stream of device list updates. Includes both local and remote users.
+CREATE TABLE device_lists_stream (
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
+
+
+-- The stream of updates to send to other servers. We keep at least one row
+-- per (destination, user) that has been sent so that the prev_id for any
+-- new updates can be calculated.
+CREATE TABLE device_lists_outbound_pokes (
+    destination TEXT NOT NULL,
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    sent BOOLEAN NOT NULL,
+    ts BIGINT NOT NULL  -- So that in future we can clear out pokes to dead servers
+);
+
+CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
+CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);
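One device-list change fans out to a cross product of rows in device_lists_outbound_pokes, all sharing a stream_id. For two devices and two destinations (all values illustrative):

    now = 1485000000000
    rows = [
        {"destination": dest, "stream_id": 10, "user_id": "@u:example.com",
         "device_id": dev, "sent": False, "ts": now}
        for dest in ("one.example.com", "two.example.com")
        for dev in ("DEV1", "DEV2")
    ]
    assert len(rows) == 4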
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7d34dd03bf..d1d653327c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -232,58 +232,6 @@ class StateStore(SQLBaseStore):
 
             return count
 
-    @defer.inlineCallbacks
-    def get_current_state(self, room_id, event_type=None, state_key=""):
-        if event_type and state_key is not None:
-            result = yield self.get_current_state_for_key(
-                room_id, event_type, state_key
-            )
-            defer.returnValue(result)
-
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? "
-            )
-
-            if event_type and state_key is not None:
-                sql += " AND type = ? AND state_key = ? "
-                args = (room_id, event_type, state_key)
-            elif event_type:
-                sql += " AND type = ?"
-                args = (room_id, event_type)
-            else:
-                args = (room_id, )
-
-            txn.execute(sql, args)
-            results = txn.fetchall()
-
-            return [r[0] for r in results]
-
-        event_ids = yield self.runInteraction("get_current_state", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @defer.inlineCallbacks
-    def get_current_state_for_key(self, room_id, event_type, state_key):
-        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
-
-    @cached(num_args=3)
-    def _get_current_state_for_key(self, room_id, event_type, state_key):
-        def f(txn):
-            sql = (
-                "SELECT event_id FROM current_state_events"
-                " WHERE room_id = ? AND type = ? AND state_key = ?"
-            )
-
-            args = (room_id, event_type, state_key)
-            txn.execute(sql, args)
-            results = txn.fetchall()
-            return [r[0] for r in results]
-        return self.runInteraction("get_current_state_for_key", f)
-
     @cached(num_args=2, max_entries=100000, iterable=True)
     def _get_state_group_from_group(self, group, types):
         raise NotImplementedError()
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 4d44c3d4ca..91a59b0bae 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -44,6 +44,7 @@ class EventSources(object):
     def get_current_token(self):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
+        device_list_key = self.store.get_device_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -63,6 +64,7 @@ class EventSources(object):
             ),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
+            device_list_key=device_list_key,
         )
         defer.returnValue(token)
 
@@ -70,6 +72,7 @@ class EventSources(object):
     def get_current_token_for_room(self, room_id):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
+        device_list_key = self.store.get_device_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -89,5 +92,6 @@ class EventSources(object):
             ),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
+            device_list_key=device_list_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 3a3ab21d17..9666f9d73f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -158,6 +158,7 @@ class StreamToken(
         "account_data_key",
         "push_rules_key",
         "to_device_key",
+        "device_list_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -195,6 +196,7 @@ class StreamToken(
             or (int(other.account_data_key) < int(self.account_data_key))
             or (int(other.push_rules_key) < int(self.push_rules_key))
             or (int(other.to_device_key) < int(self.to_device_key))
+            or (int(other.device_list_key) < int(self.device_list_key))
         )
 
     def copy_and_advance(self, key, new_value):
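The new device_list_key becomes one more "_"-separated field in the serialized token, which is why the fixture tokens in the room-message tests later in this diff grow an extra "_0":

    token = "s0_0_0_0_0_0_0_0"
    assert len(token.split("_")) == 8  # room, presence, typing, receipt,
                                       # account_data, push_rules,
                                       # to_device, device_list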
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 85a970a6c9..2eaaa8253c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def setUp(self):
-        hs = yield utils.setup_test_homeserver(handlers=None)
-        self.handler = synapse.handlers.device.DeviceHandler(hs)
+        hs = yield utils.setup_test_homeserver()
+        self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def test_device_is_created_if_doesnt_exist(self):
         res = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="display name"
         )
         self.assertEqual(res, "fco")
 
-        dev = yield self.handler.store.get_device("boris", "fco")
+        dev = yield self.handler.store.get_device("@boris:foo", "fco")
         self.assertEqual(dev["display_name"], "display name")
 
     @defer.inlineCallbacks
     def test_device_is_preserved_if_exists(self):
         res1 = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="display name"
         )
         self.assertEqual(res1, "fco")
 
         res2 = yield self.handler.check_device_registered(
-            user_id="boris",
+            user_id="@boris:foo",
             device_id="fco",
             initial_device_display_name="new display name"
         )
         self.assertEqual(res2, "fco")
 
-        dev = yield self.handler.store.get_device("boris", "fco")
+        dev = yield self.handler.store.get_device("@boris:foo", "fco")
         self.assertEqual(dev["display_name"], "display name")
 
     @defer.inlineCallbacks
     def test_device_id_is_made_up_if_unspecified(self):
         device_id = yield self.handler.check_device_registered(
-            user_id="theresa",
+            user_id="@theresa:foo",
             device_id=None,
             initial_device_display_name="display"
         )
 
-        dev = yield self.handler.store.get_device("theresa", device_id)
+        dev = yield self.handler.store.get_device("@theresa:foo", device_id)
         self.assertEqual(dev["display_name"], "display")
 
     @defer.inlineCallbacks
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 5d602c1531..ceb9aa5765 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
     def setUp(self):
         self.mock_federation = Mock(spec=[
             "make_query",
+            "register_edu_handler",
         ])
 
         self.query_handlers = {}
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f1f664275f..979cebf600 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
     def setUp(self):
         self.mock_federation = Mock(spec=[
             "make_query",
+            "register_edu_handler",
         ])
 
         self.query_handlers = {}
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index c718d1f98f..f88d2be7c5 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
                 "get_received_txn_response",
                 "set_received_txn_response",
                 "get_destination_retry_timings",
+                "get_devices_by_remote",
             ]),
             state_handler=self.state_handler,
             handlers=None,
@@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
             defer.succeed(retry_timings_res)
         )
 
+        self.datastore.get_devices_by_remote.return_value = (0, [])
+
         def get_received_txn_response(*args):
             return defer.succeed(None)
         self.datastore.get_received_txn_response = get_received_txn_response
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 44e859b5d1..6acb8ab758 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
     @defer.inlineCallbacks
     def test_room_members(self):
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
+        yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.replicate()
         yield self.check("get_rooms_for_user", (USER_ID,), [])
         yield self.check("get_users_in_room", (ROOM_ID,), [])
@@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         )])
         yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
 
-        # Join the room clobbering the state.
-        # This should remove any evidence of the other user being in the room.
         yield self.persist(
             type="m.room.member", key=USER_ID, membership="join",
-            reset_state=[create]
         )
         yield self.replicate()
-        yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
-        yield self.check("get_rooms_for_user", (USER_ID_2,), [])
+        yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID])
 
     @defer.inlineCallbacks
     def test_get_latest_event_ids_in_room(self):
@@ -123,51 +119,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_current_state(self):
-        # Create the room.
-        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
-        )
-
-        # Join the room.
-        join1 = yield self.persist(
-            type="m.room.member", key=USER_ID, membership="join",
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
-            [join1]
-        )
-
-        # Add some other user to the room.
-        join2 = yield self.persist(
-            type="m.room.member", key=USER_ID_2, membership="join",
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
-            [join2]
-        )
-
-        # Leave the room, then rejoin the room clobbering state.
-        yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
-        join3 = yield self.persist(
-            type="m.room.member", key=USER_ID, membership="join",
-            reset_state=[create]
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
-            []
-        )
-        yield self.check(
-            "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
-            [join3]
-        )
-
-    @defer.inlineCallbacks
     def test_redactions(self):
         yield self.persist(type="m.room.create", key="", creator=USER_ID)
         yield self.persist(type="m.room.member", key=USER_ID, membership="join")
@@ -283,6 +234,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         if depth is None:
             depth = self.event_id
 
+        if not prev_events:
+            latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
+                room_id
+            )
+            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
+
         event_dict = {
             "sender": sender,
             "type": type,
@@ -309,12 +266,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             state_ids = {
                 key: e.event_id for key, e in state.items()
             }
+            context = EventContext()
+            context.current_state_ids = state_ids
+            context.prev_state_ids = state_ids
+        elif not backfill:
+            state_handler = self.hs.get_state_handler()
+            context = yield state_handler.compute_event_context(event)
         else:
-            state_ids = None
+            context = EventContext()
 
-        context = EventContext()
-        context.current_state_ids = state_ids
-        context.prev_state_ids = state_ids
         context.push_actions = push_actions
 
         ordering = None
@@ -324,7 +284,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
             )
         else:
             ordering, _ = yield self.master_store.persist_event(
-                event, context, current_state=reset_state
+                event, context,
             )
 
         if ordering:
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 6bce352c5f..d746ea8568 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
 
     @defer.inlineCallbacks
     def test_stream_token_is_accepted_for_fwd_pagianation(self):
-        token = "s0_0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0_0"
         (code, response) = yield self.mock_resource.trigger_get(
             "/rooms/%s/messages?access_token=x&from=%s" %
             (self.room_id, token))
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 9ff1abcd80..9e98d0e330 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
             event_cache_size=1,
             password_providers=[],
         )
-        hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+        hs = yield setup_test_homeserver(
+            config=config,
+            federation_sender=Mock(),
+            replication_layer=Mock(),
+        )
 
         self.as_token = "token1"
         self.as_url = "some_url"
@@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
             event_cache_size=1,
             password_providers=[],
         )
-        hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
+        hs = yield setup_test_homeserver(
+            config=config,
+            federation_sender=Mock(),
+            replication_layer=Mock(),
+        )
         self.db_pool = hs.get_db_pool()
 
         self.as_list = [
@@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         ApplicationServiceStore(hs)
@@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
@@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver(
             config=config,
             datastore=Mock(),
-            federation_sender=Mock()
+            federation_sender=Mock(),
+            replication_layer=Mock(),
         )
 
         with self.assertRaises(ConfigError) as cm:
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 453bc61438..bfa6294250 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -35,6 +35,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         now = 1470174257070
         json = '{ "key": "value" }'
 
+        yield self.store.store_device(
+            "user", "device", None
+        )
+
         yield self.store.set_e2e_device_keys(
             "user", "device", now, json)
 
@@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
     def test_multiple_devices(self):
         now = 1470174257070
 
+        yield self.store.store_device(
+            "user1", "device1", None
+        )
+        yield self.store.store_device(
+            "user1", "device2", None
+        )
+        yield self.store.store_device(
+            "user2", "device1", None
+        )
+        yield self.store.store_device(
+            "user2", "device2", None
+        )
+
         yield self.store.set_e2e_device_keys(
             "user1", "device1", now, 'json11')
         yield self.store.set_e2e_device_keys(