summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erikj@jki.re>2017-01-30 14:35:21 +0000
committerGitHub <noreply@github.com>2017-01-30 14:35:21 +0000
commit9636b2407d60ab544d8ea713800132e203967a11 (patch)
treecca8f817c4aeee8b26f7840e877a66a7a6313384 /synapse
parentMerge pull request #1852 from matrix-org/paul/issue-1382 (diff)
parentRename func (diff)
downloadsynapse-9636b2407d60ab544d8ea713800132e203967a11.tar.xz
Merge pull request #1857 from matrix-org/erikj/device_list_stream
Implement device lists updates over federation
Diffstat (limited to '')
-rw-r--r--synapse/app/federation_sender.py3
-rw-r--r--synapse/app/synchrotron.py26
-rw-r--r--synapse/federation/federation_client.py10
-rw-r--r--synapse/federation/federation_server.py3
-rw-r--r--synapse/federation/transaction_queue.py133
-rw-r--r--synapse/federation/transport/client.py26
-rw-r--r--synapse/federation/transport/server.py8
-rw-r--r--synapse/handlers/device.py141
-rw-r--r--synapse/handlers/e2e_keys.py43
-rw-r--r--synapse/handlers/sync.py28
-rw-r--r--synapse/replication/resource.py20
-rw-r--r--synapse/replication/slave/storage/devices.py72
-rw-r--r--synapse/rest/client/v2_alpha/sync.py6
-rw-r--r--synapse/storage/__init__.py11
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/devices.py455
-rw-r--r--synapse/storage/end_to_end_keys.py31
-rw-r--r--synapse/storage/schema/delta/40/device_list_streams.sql59
-rw-r--r--synapse/streams/events.py4
-rw-r--r--synapse/types.py2
20 files changed, 983 insertions, 104 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index ec06620efb..411e47d98d 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.transactions import TransactionStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.storage.engines import create_engine
 from synapse.storage.presence import UserPresenceState
 from synapse.util.async import sleep
@@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
 
 class FederationSenderSlaveStore(
     SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
-    SlavedRegistrationStore,
+    SlavedRegistrationStore, SlavedDeviceStore,
 ):
     pass
 
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 4dfc2dc648..b3fb408cfd 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
 from synapse.replication.slave.storage.presence import SlavedPresenceStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.server import HomeServer
 from synapse.storage.client_ips import ClientIpStore
@@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
     SlavedFilteringStore,
     SlavedPresenceStore,
     SlavedDeviceInboxStore,
+    SlavedDeviceStore,
     RoomStore,
     BaseSlavedStore,
     ClientIpStore,  # After BaseSlavedStore because the constructor is different
@@ -380,6 +382,27 @@ class SynchrotronServer(HomeServer):
                         stream_key, position, users=users, rooms=rooms
                     )
 
+        @defer.inlineCallbacks
+        def notify_device_list_update(result):
+            stream = result.get("device_lists")
+            if not stream:
+                return
+
+            position_index = stream["field_names"].index("position")
+            user_index = stream["field_names"].index("user_id")
+
+            for row in stream["rows"]:
+                position = row[position_index]
+                user_id = row[user_index]
+
+                rooms = yield store.get_rooms_for_user(user_id)
+                room_ids = [r.room_id for r in rooms]
+
+                notifier.on_new_event(
+                    "device_list_key", position, rooms=room_ids,
+                )
+
+        @defer.inlineCallbacks
         def notify(result):
             stream = result.get("events")
             if stream:
@@ -417,6 +440,7 @@ class SynchrotronServer(HomeServer):
             notify_from_stream(
                 result, "to_device", "to_device_key", user="user_id"
             )
+            yield notify_device_list_update(result)
 
         while True:
             try:
@@ -427,7 +451,7 @@ class SynchrotronServer(HomeServer):
                 yield store.process_replication(result)
                 typing_handler.process_replication(result)
                 yield presence_handler.process_replication(result)
-                notify(result)
+                yield notify(result)
             except:
                 logger.exception("Error replicating from %r", replication_url)
                 yield sleep(5)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9175bb33d..b5bcfd705a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -127,6 +127,16 @@ class FederationClient(FederationBase):
         )
 
     @log_function
+    def query_user_devices(self, destination, user_id, timeout=30000):
+        """Query the device keys for a list of user ids hosted on a remote
+        server.
+        """
+        sent_queries_counter.inc("user_devices")
+        return self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
+
+    @log_function
     def claim_client_keys(self, destination, content, timeout):
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 862ccbef5d..e922b7ff4a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -416,6 +416,9 @@ class FederationServer(FederationBase):
     def on_query_client_keys(self, origin, content):
         return self.on_query_request("client_keys", content)
 
+    def on_query_user_devices(self, origin, user_id):
+        return self.on_query_request("user_devices", user_id)
+
     @defer.inlineCallbacks
     @log_function
     def on_claim_client_keys(self, origin, content):
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6b3a7abb9e..d18f6b6cfd 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -100,6 +100,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}
 
         self.last_device_stream_id_by_dest = {}
+        self.last_device_list_stream_id_by_dest = {}
 
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
@@ -305,62 +306,74 @@ class TransactionQueue(object):
             yield run_on_reactor()
 
             while True:
-                    pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-                    pending_edus = self.pending_edus_by_dest.pop(destination, [])
-                    pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                    pending_failures = self.pending_failures_by_dest.pop(destination, [])
+                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+                pending_edus = self.pending_edus_by_dest.pop(destination, [])
+                pending_presence = self.pending_presence_by_dest.pop(destination, {})
+                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-                    pending_edus.extend(
-                        self.pending_edus_keyed_by_dest.pop(destination, {}).values()
-                    )
+                pending_edus.extend(
+                    self.pending_edus_keyed_by_dest.pop(destination, {}).values()
+                )
 
-                    limiter = yield get_retry_limiter(
-                        destination,
-                        self.clock,
-                        self.store,
-                    )
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-                    device_message_edus, device_stream_id = (
-                        yield self._get_new_device_messages(destination)
-                    )
+                device_message_edus, device_stream_id, dev_list_id = (
+                    yield self._get_new_device_messages(destination)
+                )
 
-                    pending_edus.extend(device_message_edus)
-                    if pending_presence:
-                        pending_edus.append(
-                            Edu(
-                                origin=self.server_name,
-                                destination=destination,
-                                edu_type="m.presence",
-                                content={
-                                    "push": [
-                                        format_user_presence_state(
-                                            presence, self.clock.time_msec()
-                                        )
-                                        for presence in pending_presence.values()
-                                    ]
-                                },
-                            )
+                pending_edus.extend(device_message_edus)
+                if pending_presence:
+                    pending_edus.append(
+                        Edu(
+                            origin=self.server_name,
+                            destination=destination,
+                            edu_type="m.presence",
+                            content={
+                                "push": [
+                                    format_user_presence_state(
+                                        presence, self.clock.time_msec()
+                                    )
+                                    for presence in pending_presence.values()
+                                ]
+                            },
                         )
+                    )
+
+                if pending_pdus:
+                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                                 destination, len(pending_pdus))
 
-                    if pending_pdus:
-                        logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                     destination, len(pending_pdus))
+                if not pending_pdus and not pending_edus and not pending_failures:
+                    logger.debug("TX [%s] Nothing to send", destination)
+                    self.last_device_stream_id_by_dest[destination] = (
+                        device_stream_id
+                    )
+                    return
 
-                    if not pending_pdus and not pending_edus and not pending_failures:
-                        logger.debug("TX [%s] Nothing to send", destination)
-                        self.last_device_stream_id_by_dest[destination] = (
-                            device_stream_id
+                success = yield self._send_new_transaction(
+                    destination, pending_pdus, pending_edus, pending_failures,
+                    limiter=limiter,
+                )
+                if success:
+                    # Remove the acknowledged device messages from the database
+                    # Only bother if we actually sent some device messages
+                    if device_message_edus:
+                        yield self.store.delete_device_msgs_for_remote(
+                            destination, device_stream_id
+                        )
+                        logger.info("Marking as sent %r %r", destination, dev_list_id)
+                        yield self.store.mark_as_sent_devices_by_remote(
+                            destination, dev_list_id
                         )
-                        return
 
-                    success = yield self._send_new_transaction(
-                        destination, pending_pdus, pending_edus, pending_failures,
-                        device_stream_id,
-                        should_delete_from_device_stream=bool(device_message_edus),
-                        limiter=limiter,
-                    )
-                    if not success:
-                        break
+                    self.last_device_stream_id_by_dest[destination] = device_stream_id
+                    self.last_device_list_stream_id_by_dest[destination] = dev_list_id
+                else:
+                    break
         except NotRetryingDestination:
             logger.debug(
                 "TX [%s] not ready for retry yet - "
@@ -387,13 +400,26 @@ class TransactionQueue(object):
             )
             for content in contents
         ]
-        defer.returnValue((edus, stream_id))
+
+        last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
+        now_stream_id, results = yield self.store.get_devices_by_remote(
+            destination, last_device_list
+        )
+        edus.extend(
+            Edu(
+                origin=self.server_name,
+                destination=destination,
+                edu_type="m.device_list_update",
+                content=content,
+            )
+            for content in results
+        )
+        defer.returnValue((edus, stream_id, now_stream_id))
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
     def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, device_stream_id,
-                              should_delete_from_device_stream, limiter):
+                              pending_failures, limiter):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
@@ -504,13 +530,6 @@ class TransactionQueue(object):
                         "Failed to send event %s to %s", p.event_id, destination
                     )
                 success = False
-            else:
-                # Remove the acknowledged device messages from the database
-                if should_delete_from_device_stream:
-                    yield self.store.delete_device_msgs_for_remote(
-                        destination, device_stream_id
-                    )
-                self.last_device_stream_id_by_dest[destination] = device_stream_id
         except RuntimeError as e:
             # We capture this here as there as nothing actually listens
             # for this finishing functions deferred.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 915af34409..f49e8a2cc4 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -348,6 +348,32 @@ class TransportLayerClient(object):
 
     @defer.inlineCallbacks
     @log_function
+    def query_user_devices(self, destination, user_id, timeout):
+        """Query the devices for a user id hosted on a remote server.
+
+        Response:
+            {
+              "stream_id": "...",
+              "devices": [ { ... } ]
+            }
+
+        Args:
+            destination(str): The server to query.
+            query_content(dict): The user ids to query.
+        Returns:
+            A dict containg the device keys.
+        """
+        path = PREFIX + "/user/devices/" + user_id
+
+        content = yield self.client.get_json(
+            destination=destination,
+            path=path,
+            timeout=timeout,
+        )
+        defer.returnValue(content)
+
+    @defer.inlineCallbacks
+    @log_function
     def claim_client_keys(self, destination, query_content, timeout):
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 159dbd1747..c840da834c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
         return self.handler.on_query_client_keys(origin, content)
 
 
+class FederationUserDevicesQueryServlet(BaseFederationServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    def on_GET(self, origin, content, query, user_id):
+        return self.handler.on_query_user_devices(origin, user_id)
+
+
 class FederationClientKeysClaimServlet(BaseFederationServlet):
     PATH = "/user/keys/claim"
 
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
     FederationGetMissingEventsServlet,
     FederationEventAuthServlet,
     FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index aa68755936..6fefb85890 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,8 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
+from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
 
@@ -27,6 +29,21 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
+        self.state = hs.get_state_handler()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
+
+        hs.get_distributor().observe("user_left_room", self.user_left_room)
+
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
                                 initial_device_display_name=None):
@@ -45,29 +62,29 @@ class DeviceHandler(BaseHandler):
             str: device id (generated if none was supplied)
         """
         if device_id is not None:
-            yield self.store.store_device(
+            new_device = yield self.store.store_device(
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
-                ignore_if_known=True,
             )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
         # times in case of a clash.
         attempts = 0
         while attempts < 5:
-            try:
-                device_id = stringutils.random_string(10).upper()
-                yield self.store.store_device(
-                    user_id=user_id,
-                    device_id=device_id,
-                    initial_device_display_name=initial_device_display_name,
-                    ignore_if_known=False,
-                )
+            device_id = stringutils.random_string(10).upper()
+            new_device = yield self.store.store_device(
+                user_id=user_id,
+                device_id=device_id,
+                initial_device_display_name=initial_device_display_name,
+            )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
                 defer.returnValue(device_id)
-            except errors.StoreError:
-                attempts += 1
+            attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
 
@@ -147,6 +164,8 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
+        yield self.notify_device_update(user_id, [device_id])
+
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
         """ Update the given device
@@ -166,12 +185,110 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
+            yield self.notify_device_update(user_id, [device_id])
         except errors.StoreError, e:
             if e.code == 404:
                 raise errors.NotFoundError()
             else:
                 raise
 
+    @defer.inlineCallbacks
+    def notify_device_update(self, user_id, device_ids):
+        """Notify that a user's device(s) has changed. Pokes the notifier, and
+        remote servers if the user is local.
+        """
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        room_ids = [r.room_id for r in rooms]
+
+        hosts = set()
+        if self.hs.is_mine_id(user_id):
+            for room_id in room_ids:
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.discard(self.server_name)
+
+        position = yield self.store.add_device_change_to_streams(
+            user_id, device_ids, list(hosts)
+        )
+
+        yield self.notifier.on_new_event(
+            "device_list_key", position, rooms=room_ids,
+        )
+
+        if hosts:
+            logger.info("Sending device list update notif to: %r", hosts)
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    @defer.inlineCallbacks
+    def _incoming_device_list_update(self, origin, edu_content):
+        user_id = edu_content["user_id"]
+        device_id = edu_content["device_id"]
+        stream_id = edu_content["stream_id"]
+        prev_ids = edu_content.get("prev_id", [])
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            logger.warning("Got device list update edu for %r from %r", user_id, origin)
+            return
+
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        with (yield self._remote_edue_linearizer.queue(user_id)):
+            # If the prev id matches whats in our cache table, then we don't need
+            # to resync the users device list, otherwise we do.
+            resync = True
+            if len(prev_ids) == 1:
+                extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+                    user_id
+                )
+                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+                if str(extremity) == str(prev_ids[0]):
+                    resync = False
+
+            if resync:
+                # Fetch all devices for the user.
+                result = yield self.federation.query_user_devices(origin, user_id)
+                stream_id = result["stream_id"]
+                devices = result["devices"]
+                yield self.store.update_remote_device_list_cache(
+                    user_id, devices, stream_id,
+                )
+                device_ids = [device["device_id"] for device in devices]
+                yield self.notify_device_update(user_id, device_ids)
+            else:
+                # Simply update the single device, since we know that is the only
+                # change (becuase of the single prev_id matching the current cache)
+                content = dict(edu_content)
+                for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+                    content.pop(key, None)
+                yield self.store.update_remote_device_list_cache_entry(
+                    user_id, device_id, content, stream_id,
+                )
+                yield self.notify_device_update(user_id, [device_id])
+
+    @defer.inlineCallbacks
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
+    @defer.inlineCallbacks
+    def user_left_room(self, user, room_id):
+        user_id = user.to_string()
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We no longer share rooms with this user, so we'll no longer
+            # receive device updates. Mark this in DB.
+            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index b63a660c06..a16b9def8d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
             if self.is_mine_id(user_id):
                 local_query[user_id] = device_ids
             else:
-                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_ids
+                remote_queries[user_id] = device_ids
 
-        # do the queries
+        # Firt get local devices.
         failures = {}
         results = {}
         if local_query:
@@ -85,9 +84,42 @@ class E2eKeysHandler(object):
                 if user_id in local_query:
                     results[user_id] = keys
 
+        # Now attempt to get any remote devices from our local cache.
+        remote_queries_not_in_cache = {}
+        if remote_queries:
+            query_list = []
+            for user_id, device_ids in remote_queries.iteritems():
+                if device_ids:
+                    query_list.extend((user_id, device_id) for device_id in device_ids)
+                else:
+                    query_list.append((user_id, None))
+
+            user_ids_not_in_cache, remote_results = (
+                yield self.store.get_user_devices_from_cache(
+                    query_list
+                )
+            )
+            for user_id, devices in remote_results.iteritems():
+                user_devices = results.setdefault(user_id, {})
+                for device_id, device in devices.iteritems():
+                    keys = device.get("keys", None)
+                    device_display_name = device.get("device_display_name", None)
+                    if keys:
+                        result = dict(keys)
+                        unsigned = result.setdefault("unsigned", {})
+                        if device_display_name:
+                            unsigned["device_display_name"] = device_display_name
+                        user_devices[device_id] = result
+
+            for user_id in user_ids_not_in_cache:
+                domain = get_domain_from_id(user_id)
+                r = remote_queries_not_in_cache.setdefault(domain, {})
+                r[user_id] = remote_queries[user_id]
+
+        # Now fetch any devices that we don't have in our cache
         @defer.inlineCallbacks
         def do_remote_query(destination):
-            destination_query = remote_queries[destination]
+            destination_query = remote_queries_not_in_cache[destination]
             try:
                 limiter = yield get_retry_limiter(
                     destination, self.clock, self.store
@@ -119,7 +151,7 @@ class E2eKeysHandler(object):
 
         yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
-            for destination in remote_queries
+            for destination in remote_queries_not_in_cache
         ]))
 
         defer.returnValue({
@@ -259,6 +291,7 @@ class E2eKeysHandler(object):
                 user_id, device_id, time_now,
                 encode_canonical_json(device_keys)
             )
+            yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c880f61685..9199f20817 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
     "to_device",  # List of direct messages for the device.
+    "device_lists",  # List of user_ids whose devices have chanegd
 ])):
     __slots__ = []
 
@@ -544,6 +545,10 @@ class SyncHandler(object):
 
         yield self._generate_sync_entry_for_to_device(sync_result_builder)
 
+        device_lists = yield self._generate_sync_entry_for_device_list(
+            sync_result_builder
+        )
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -551,10 +556,33 @@ class SyncHandler(object):
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
+            device_lists=device_lists,
             next_batch=sync_result_builder.now_token,
         ))
 
     @defer.inlineCallbacks
+    def _generate_sync_entry_for_device_list(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+
+        if since_token and since_token.device_list_key:
+            rooms = yield self.store.get_rooms_for_user(user_id)
+            room_ids = set(r.room_id for r in rooms)
+
+            user_ids_changed = set()
+            changed = yield self.store.get_user_whose_devices_changed(
+                since_token.device_list_key
+            )
+            for other_user_id in changed:
+                other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+                if room_ids.intersection(e.room_id for e in other_rooms):
+                    user_ids_changed.add(other_user_id)
+
+            defer.returnValue(user_ids_changed)
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the portion of the sync response. Populates
         `sync_result_builder` with the result.
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 4616e9b34a..a30e647474 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -46,6 +46,7 @@ STREAM_NAMES = (
     ("to_device",),
     ("public_rooms",),
     ("federation",),
+    ("device_lists",),
 )
 
 
@@ -140,6 +141,7 @@ class ReplicationResource(Resource):
         caches_token = self.store.get_cache_stream_token()
         public_rooms_token = self.store.get_current_public_room_stream_id()
         federation_token = self.federation_sender.get_current_token()
+        device_list_token = self.store.get_device_stream_token()
 
         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -155,6 +157,7 @@ class ReplicationResource(Resource):
             int(stream_token.to_device_key),
             int(public_rooms_token),
             int(federation_token),
+            int(device_list_token),
         ))
 
     @request_handler()
@@ -214,6 +217,7 @@ class ReplicationResource(Resource):
         yield self.caches(writer, current_token, limit, request_streams)
         yield self.to_device(writer, current_token, limit, request_streams)
         yield self.public_rooms(writer, current_token, limit, request_streams)
+        yield self.device_lists(writer, current_token, limit, request_streams)
         self.federation(writer, current_token, limit, request_streams, federation_ack)
         self.streams(writer, current_token, request_streams)
 
@@ -495,6 +499,20 @@ class ReplicationResource(Resource):
                 "position", "type", "content",
             ), position=upto_token)
 
+    @defer.inlineCallbacks
+    def device_lists(self, writer, current_token, limit, request_streams):
+        current_position = current_token.device_lists
+
+        device_lists = request_streams.get("device_lists")
+
+        if device_lists is not None and device_lists != current_position:
+            changes = yield self.store.get_all_device_list_changes_for_remotes(
+                device_lists,
+            )
+            writer.write_header_and_rows("device_lists", changes, (
+                "position", "user_id", "destination",
+            ), position=current_position)
+
 
 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -527,7 +545,7 @@ class _Writer(object):
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
     "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
-    "federation",
+    "federation", "device_lists",
 ))):
     __slots__ = []
 
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
new file mode 100644
index 0000000000..ca46aa17b6
--- /dev/null
+++ b/synapse/replication/slave/storage/devices.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseSlavedStore
+from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class SlavedDeviceStore(BaseSlavedStore):
+    def __init__(self, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+
+        self.hs = hs
+
+        self._device_list_id_gen = SlavedIdTracker(
+            db_conn, "device_lists_stream", "stream_id",
+        )
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
+    get_device_stream_token = DataStore.get_device_stream_token.__func__
+    get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
+    get_devices_by_remote = DataStore.get_devices_by_remote.__func__
+    _get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
+    _get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
+    mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
+    _mark_as_sent_devices_by_remote_txn = (
+        DataStore._mark_as_sent_devices_by_remote_txn.__func__
+    )
+
+    def stream_positions(self):
+        result = super(SlavedDeviceStore, self).stream_positions()
+        result["device_lists"] = self._device_list_id_gen.get_current_token()
+        return result
+
+    def process_replication(self, result):
+        stream = result.get("device_lists")
+        if stream:
+            self._device_list_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                stream_id = row[0]
+                user_id = row[1]
+                destination = row[2]
+
+                self._device_list_stream_cache.entity_has_changed(
+                    user_id, stream_id
+                )
+
+                if destination:
+                    self._device_list_federation_stream_cache.entity_has_changed(
+                        destination, stream_id
+                    )
+
+        return super(SlavedDeviceStore, self).process_replication(result)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 7199ec883a..b3d8001638 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
         )
 
         archived = self.encode_archived(
-            sync_result.archived, time_now, requester.access_token_id, filter.event_fields
+            sync_result.archived, time_now, requester.access_token_id,
+            filter.event_fields,
         )
 
         response_content = {
             "account_data": {"events": sync_result.account_data},
             "to_device": {"events": sync_result.to_device},
+            "device_lists": {
+                "changed": list(sync_result.device_lists),
+            },
             "presence": self.encode_presence(
                 sync_result.presence, time_now
             ),
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e8495f1eb9..b9968debe5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._public_room_id_gen = StreamIdGenerator(
             db_conn, "public_room_list_stream", "stream_id"
         )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn, "device_lists_stream", "stream_id",
+        )
 
         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore,
             prefilled_cache=device_outbox_prefill,
         )
 
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max,
+        )
+
         cur = LoggingTransaction(
             db_conn.cursor(),
             name="_find_stream_orderings_for_times_txn",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 963ef999d5..05374682fd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -387,6 +387,10 @@ class SQLBaseStore(object):
         Args:
             table : string giving the table name
             values : dict of new column names and values for them
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
         """
         try:
             yield self.runInteraction(
@@ -398,6 +402,8 @@ class SQLBaseStore(object):
             # a cursor after we receive an error from the db.
             if not or_ignore:
                 raise
+            defer.returnValue(False)
+        defer.returnValue(True)
 
     @staticmethod
     def _simple_insert_txn(txn, table, values):
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 17920d4480..e68ee50152 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import ujson as json
 
 from twisted.internet import defer
 
@@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
 
 
 class DeviceStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(DeviceStore, self).__init__(hs)
+
+        self._clock.looping_call(
+            self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+        )
+
     @defer.inlineCallbacks
     def store_device(self, user_id, device_id,
-                     initial_device_display_name,
-                     ignore_if_known=True):
+                     initial_device_display_name):
         """Ensure the given device is known; add it to the store if not
 
         Args:
             user_id (str): id of user associated with the device
             device_id (str): id of device
             initial_device_display_name (str): initial displayname of the
-               device
-            ignore_if_known (bool): ignore integrity errors which mean the
-               device is already known
+               device. Ignored if device exists.
         Returns:
-            defer.Deferred
-        Raises:
-            StoreError: if ignore_if_known is False and the device was already
-               known
+            defer.Deferred: boolean whether the device was inserted or an
+                existing device existed with that ID.
         """
         try:
-            yield self._simple_insert(
+            inserted = yield self._simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
                     "display_name": initial_device_display_name
                 },
                 desc="store_device",
-                or_ignore=ignore_if_known,
+                or_ignore=True,
             )
+            defer.returnValue(inserted)
         except Exception as e:
             logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
                          " display_name=%s(%r) failed: %s",
@@ -139,3 +143,432 @@ class DeviceStore(SQLBaseStore):
         )
 
         defer.returnValue({d["device_id"]: d for d in devices})
+
+    def get_device_list_last_stream_id_for_remote(self, user_id):
+        """Get the last stream_id we got for a user. May be None if we haven't
+        got any information for them.
+        """
+        return self._simple_select_one_onecol(
+            table="device_lists_remote_extremeties",
+            keyvalues={"user_id": user_id},
+            retcol="stream_id",
+            desc="get_device_list_remote_extremity",
+            allow_none=True,
+        )
+
+    def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+        """Mark that we no longer track device lists for remote user.
+        """
+        return self._simple_delete(
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+    def update_remote_device_list_cache_entry(self, user_id, device_id, content,
+                                              stream_id):
+        """Updates a single user's device in the cache.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache_entry",
+            self._update_remote_device_list_cache_entry_txn,
+            user_id, device_id, content, stream_id,
+        )
+
+    def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
+                                                   content, stream_id):
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+                "device_id": device_id,
+            },
+            values={
+                "content": json.dumps(content),
+            }
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def update_remote_device_list_cache(self, user_id, devices, stream_id):
+        """Replace the cache of the remote user's devices.
+        """
+        return self.runInteraction(
+            "update_remote_device_list_cache",
+            self._update_remote_device_list_cache_txn,
+            user_id, devices, stream_id,
+        )
+
+    def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
+                                             stream_id):
+        self._simple_delete_txn(
+            txn,
+            table="device_lists_remote_cache",
+            keyvalues={
+                "user_id": user_id,
+            },
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_remote_cache",
+            values=[
+                {
+                    "user_id": user_id,
+                    "device_id": content["device_id"],
+                    "content": json.dumps(content),
+                }
+                for content in devices
+            ]
+        )
+
+        self._simple_upsert_txn(
+            txn,
+            table="device_lists_remote_extremeties",
+            keyvalues={
+                "user_id": user_id,
+            },
+            values={
+                "stream_id": stream_id,
+            }
+        )
+
+    def get_devices_by_remote(self, destination, from_stream_id):
+        """Get stream of updates to send to remote servers
+
+        Returns:
+            (now_stream_id, [ { updates }, .. ])
+        """
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
+            destination, int(from_stream_id)
+        )
+        if not has_changed:
+            return (now_stream_id, [])
+
+        return self.runInteraction(
+            "get_devices_by_remote", self._get_devices_by_remote_txn,
+            destination, from_stream_id, now_stream_id,
+        )
+
+    def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
+                                   now_stream_id):
+        sql = """
+            SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
+            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+            GROUP BY user_id, device_id
+        """
+        txn.execute(
+            sql, (destination, from_stream_id, now_stream_id, False)
+        )
+        rows = txn.fetchall()
+
+        if not rows:
+            return (now_stream_id, [])
+
+        # maps (user_id, device_id) -> stream_id
+        query_map = {(r[0], r[1]): r[2] for r in rows}
+        devices = self._get_e2e_device_keys_txn(
+            txn, query_map.keys(), include_all_devices=True
+        )
+
+        prev_sent_id_sql = """
+            SELECT coalesce(max(stream_id), 0) as stream_id
+            FROM device_lists_outbound_pokes
+            WHERE destination = ? AND user_id = ? AND stream_id <= ?
+        """
+
+        results = []
+        for user_id, user_devices in devices.iteritems():
+            # The prev_id for the first row is always the last row before
+            # `from_stream_id`
+            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
+            rows = txn.fetchall()
+            prev_id = rows[0][0]
+            for device_id, device in user_devices.iteritems():
+                stream_id = query_map[(user_id, device_id)]
+                result = {
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "prev_id": [prev_id] if prev_id else [],
+                    "stream_id": stream_id,
+                }
+
+                prev_id = stream_id
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+        return (now_stream_id, results)
+
+    def get_user_devices_from_cache(self, query_list):
+        """Get the devices (and keys if any) for remote users from the cache.
+
+        Args:
+            query_list(list): List of (user_id, device_ids), if device_ids is
+                falsey then return all device ids for that user.
+
+        Returns:
+            (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
+            a set of user_ids and results_map is a mapping of
+            user_id -> device_id -> device_info
+        """
+        return self.runInteraction(
+            "get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
+            query_list,
+        )
+
+    def _get_user_devices_from_cache_txn(self, txn, query_list):
+        user_ids = {user_id for user_id, _ in query_list}
+
+        user_ids_in_cache = set()
+        for user_id in user_ids:
+            stream_ids = self._simple_select_onecol_txn(
+                txn,
+                table="device_lists_remote_extremeties",
+                keyvalues={
+                    "user_id": user_id,
+                },
+                retcol="stream_id",
+            )
+            if stream_ids:
+                user_ids_in_cache.add(user_id)
+
+        user_ids_not_in_cache = user_ids - user_ids_in_cache
+
+        results = {}
+        for user_id, device_id in query_list:
+            if user_id not in user_ids_in_cache:
+                continue
+
+            if device_id:
+                content = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                        "device_id": device_id,
+                    },
+                    retcol="content",
+                )
+                results.setdefault(user_id, {})[device_id] = json.loads(content)
+            else:
+                devices = self._simple_select_list_txn(
+                    txn,
+                    table="device_lists_remote_cache",
+                    keyvalues={
+                        "user_id": user_id,
+                    },
+                    retcols=("device_id", "content"),
+                )
+                results[user_id] = {
+                    device["device_id"]: json.loads(device["content"])
+                    for device in devices
+                }
+                user_ids_in_cache.discard(user_id)
+
+        return user_ids_not_in_cache, results
+
+    def get_devices_with_keys_by_user(self, user_id):
+        """Get all devices (with any device keys) for a user
+
+        Returns:
+            (stream_id, devices)
+        """
+        return self.runInteraction(
+            "get_devices_with_keys_by_user",
+            self._get_devices_with_keys_by_user_txn, user_id,
+        )
+
+    def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+        now_stream_id = self._device_list_id_gen.get_current_token()
+
+        devices = self._get_e2e_device_keys_txn(
+            txn, [(user_id, None)], include_all_devices=True
+        )
+
+        if devices:
+            user_devices = devices[user_id]
+            results = []
+            for device_id, device in user_devices.iteritems():
+                result = {
+                    "device_id": device_id,
+                }
+
+                key_json = device.get("key_json", None)
+                if key_json:
+                    result["keys"] = json.loads(key_json)
+                device_display_name = device.get("device_display_name", None)
+                if device_display_name:
+                    result["device_display_name"] = device_display_name
+
+                results.append(result)
+
+            return now_stream_id, results
+
+        return now_stream_id, []
+
+    def mark_as_sent_devices_by_remote(self, destination, stream_id):
+        """Mark that updates have successfully been sent to the destination.
+        """
+        return self.runInteraction(
+            "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
+            destination, stream_id,
+        )
+
+    def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+        sql = """
+            DELETE FROM device_lists_outbound_pokes
+            WHERE destination = ? AND stream_id < (
+                SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
+                WHERE destination = ? AND stream_id <= ?
+            )
+        """
+        txn.execute(sql, (destination, destination, stream_id,))
+
+        sql = """
+            UPDATE device_lists_outbound_pokes SET sent = ?
+            WHERE destination = ? AND stream_id <= ?
+        """
+        txn.execute(sql, (True, destination, stream_id,))
+
+    @defer.inlineCallbacks
+    def get_user_whose_devices_changed(self, from_key):
+        """Get set of users whose devices have changed since `from_key`.
+        """
+        from_key = int(from_key)
+        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
+        if changed is not None:
+            defer.returnValue(set(changed))
+
+        sql = """
+            SELECT user_id FROM device_lists_stream WHERE stream_id > ?
+        """
+        rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
+        defer.returnValue(set(row["user_id"] for row in rows))
+
+    def get_all_device_list_changes_for_remotes(self, from_key):
+        """Return a list of `(stream_id, user_id, destination)` which is the
+        combined list of changes to devices, and which destinations need to be
+        poked. `destination` may be None if no destinations need to be poked.
+        """
+        sql = """
+            SELECT stream_id, user_id, destination FROM device_lists_stream
+            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            WHERE stream_id > ?
+        """
+        return self._execute(
+            "get_users_and_hosts_device_list", None,
+            sql, from_key,
+        )
+
+    @defer.inlineCallbacks
+    def add_device_change_to_streams(self, user_id, device_ids, hosts):
+        """Persist that a user's devices have been updated, and which hosts
+        (if any) should be poked.
+        """
+        with self._device_list_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "add_device_change_to_streams", self._add_device_change_txn,
+                user_id, device_ids, hosts, stream_id,
+            )
+        defer.returnValue(stream_id)
+
+    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
+        now = self._clock.time_msec()
+
+        txn.call_after(
+            self._device_list_stream_cache.entity_has_changed,
+            user_id, stream_id,
+        )
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host, stream_id,
+            )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+                for device_id in device_ids
+            ]
+        )
+
+        self._simple_insert_many_txn(
+            txn,
+            table="device_lists_outbound_pokes",
+            values=[
+                {
+                    "destination": destination,
+                    "stream_id": stream_id,
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "sent": False,
+                    "ts": now,
+                }
+                for destination in hosts
+                for device_id in device_ids
+            ]
+        )
+
+    def get_device_stream_token(self):
+        return self._device_list_id_gen.get_current_token()
+
+    def _prune_old_outbound_device_pokes(self):
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers. We keep one entry per
+        (destination, user_id) tuple to ensure that the prev_ids remain correct
+        if the server does come back.
+        """
+        now = self._clock.time_msec()
+
+        def _prune_txn(txn):
+            select_sql = """
+                SELECT destination, user_id, max(stream_id) as stream_id
+                FROM device_lists_outbound_pokes
+                GROUP BY destination, user_id
+            """
+
+            txn.execute(select_sql)
+            rows = txn.fetchall()
+
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+            """
+
+            txn.executemany(
+                delete_sql,
+                (
+                    (now, row["destination"], row["user_id"], row["stream_id"])
+                    for row in rows
+                )
+            )
+
+        return self.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn
+        )
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 385d607056..85763f7ceb 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -12,9 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import collections
-
-import twisted.internet.defer
+from twisted.internet import defer
 
 from ._base import SQLBaseStore
 
@@ -33,10 +31,12 @@ class EndToEndKeyStore(SQLBaseStore):
             }
         )
 
-    def get_e2e_device_keys(self, query_list):
+    def get_e2e_device_keys(self, query_list, include_all_devices=False):
         """Fetch a list of device keys.
         Args:
             query_list(list): List of pairs of user_ids and device_ids.
+            include_all_devices (bool): whether to include entries for devices
+                that don't have device keys
         Returns:
             Dict mapping from user-id to dict mapping from device_id to
             dict containing "key_json", "device_display_name".
@@ -45,41 +45,42 @@ class EndToEndKeyStore(SQLBaseStore):
             return {}
 
         return self.runInteraction(
-            "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
+            "get_e2e_device_keys", self._get_e2e_device_keys_txn,
+            query_list, include_all_devices,
         )
 
-    def _get_e2e_device_keys_txn(self, txn, query_list):
+    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
         query_clauses = []
         query_params = []
 
         for (user_id, device_id) in query_list:
-            query_clause = "k.user_id = ?"
+            query_clause = "user_id = ?"
             query_params.append(user_id)
 
             if device_id:
-                query_clause += " AND k.device_id = ?"
+                query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
             query_clauses.append(query_clause)
 
         sql = (
-            "SELECT k.user_id, k.device_id, "
+            "SELECT user_id, device_id, "
             "    d.display_name AS device_display_name, "
             "    k.key_json"
-            " FROM e2e_device_keys_json k"
-            "    LEFT JOIN devices d ON d.user_id = k.user_id"
-            "      AND d.device_id = k.device_id"
+            " FROM devices d"
+            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
             " WHERE %s"
         ) % (
+            "LEFT" if include_all_devices else "INNER",
             " OR ".join("(" + q + ")" for q in query_clauses)
         )
 
         txn.execute(sql, query_params)
         rows = self.cursor_to_dict(txn)
 
-        result = collections.defaultdict(dict)
+        result = {}
         for row in rows:
-            result[row["user_id"]][row["device_id"]] = row
+            result.setdefault(row["user_id"], {})[row["device_id"]] = row
 
         return result
 
@@ -152,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
             "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
         )
 
-    @twisted.internet.defer.inlineCallbacks
+    @defer.inlineCallbacks
     def delete_e2e_keys_by_device(self, user_id, device_id):
         yield self._simple_delete(
             table="e2e_device_keys_json",
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql
new file mode 100644
index 0000000000..54841b3843
--- /dev/null
+++ b/synapse/storage/schema/delta/40/device_list_streams.sql
@@ -0,0 +1,59 @@
+/* Copyright 2017 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Cache of remote devices.
+CREATE TABLE device_lists_remote_cache (
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    content TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
+
+
+-- The last update we got for a user. Empty if we're not receiving updates for
+-- that user.
+CREATE TABLE device_lists_remote_extremeties (
+    user_id TEXT NOT NULL,
+    stream_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
+
+
+-- Stream of device lists updates. Includes both local and remotes
+CREATE TABLE device_lists_stream (
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL
+);
+
+CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
+
+
+-- The stream of updates to send to other servers. We keep at least one row
+-- per user that was sent so that the prev_id for any new updates can be
+-- calculated
+CREATE TABLE device_lists_outbound_pokes (
+    destination TEXT NOT NULL,
+    stream_id BIGINT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    sent BOOLEAN NOT NULL,
+    ts BIGINT NOT NULL  -- So that in future we can clear out pokes to dead servers
+);
+
+CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
+CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 4d44c3d4ca..91a59b0bae 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -44,6 +44,7 @@ class EventSources(object):
     def get_current_token(self):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
+        device_list_key = self.store.get_device_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -63,6 +64,7 @@ class EventSources(object):
             ),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
+            device_list_key=device_list_key,
         )
         defer.returnValue(token)
 
@@ -70,6 +72,7 @@ class EventSources(object):
     def get_current_token_for_room(self, room_id):
         push_rules_key, _ = self.store.get_push_rules_stream_token()
         to_device_key = self.store.get_to_device_stream_token()
+        device_list_key = self.store.get_device_stream_token()
 
         token = StreamToken(
             room_key=(
@@ -89,5 +92,6 @@ class EventSources(object):
             ),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
+            device_list_key=device_list_key,
         )
         defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 3a3ab21d17..9666f9d73f 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -158,6 +158,7 @@ class StreamToken(
         "account_data_key",
         "push_rules_key",
         "to_device_key",
+        "device_list_key",
     ))
 ):
     _SEPARATOR = "_"
@@ -195,6 +196,7 @@ class StreamToken(
             or (int(other.account_data_key) < int(self.account_data_key))
             or (int(other.push_rules_key) < int(self.push_rules_key))
             or (int(other.to_device_key) < int(self.to_device_key))
+            or (int(other.device_list_key) < int(self.device_list_key))
         )
 
     def copy_and_advance(self, key, new_value):