summary refs log tree commit diff
path: root/synapse/handlers/device.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-01-26 16:06:54 +0000
committerErik Johnston <erik@matrix.org>2017-01-26 16:07:24 +0000
commitc974116f197d211ba9b42159fe61cfd5957411b5 (patch)
treeb4d6e71850ba0d371e0463a5cc99439ce8072c40 /synapse/handlers/device.py
parentFix up sending of m.device_list_update edus (diff)
downloadsynapse-c974116f197d211ba9b42159fe61cfd5957411b5.tar.xz
Implement device key caching over federation
Diffstat (limited to 'synapse/handlers/device.py')
-rw-r--r--synapse/handlers/device.py85
1 files changed, 73 insertions, 12 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index ba4c48d590..2d66b3721a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,7 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
 from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
@@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
         self.state = hs.get_state_handler()
-        self.federation = hs.get_federation_sender()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
 
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
@@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
@@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                yield self.notify_device_update(user_id, device_id)
+                yield self.notify_device_update(user_id, [device_id])
                 defer.returnValue(device_id)
             attempts += 1
 
@@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
-        yield self.notify_device_update(user_id, device_id)
+        yield self.notify_device_update(user_id, [device_id])
 
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
@@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
-            yield self.notify_device_update(user_id, device_id)
+            yield self.notify_device_update(user_id, [device_id])
         except errors.StoreError, e:
             if e.code == 404:
                 raise errors.NotFoundError()
@@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
                 raise
 
     @defer.inlineCallbacks
-    def notify_device_update(self, user_id, device_id):
+    def notify_device_update(self, user_id, device_ids):
         rooms = yield self.store.get_rooms_for_user(user_id)
         room_ids = [r.room_id for r in rooms]
 
         hosts = set()
-        for room_id in room_ids:
-            users = yield self.state.get_current_user_in_room(room_id)
-            hosts.update(get_domain_from_id(u) for u in users)
-        hosts.discard(self.server_name)
+        if self.hs.is_mine_id(user_id):
+            for room_id in room_ids:
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.discard(self.server_name)
 
         position = yield self.store.add_device_change_to_streams(
-            user_id, device_id, list(hosts)
+            user_id, device_ids, list(hosts)
         )
 
         yield self.notifier.on_new_event(
             "device_list_key", position, rooms=room_ids,
         )
 
+        logger.info("Sending device list update notif to: %r", hosts)
         for host in hosts:
-            self.federation.send_device_messages(host)
+            self.federation_sender.send_device_messages(host)
 
     @defer.inlineCallbacks
     def get_device_list_changes(self, user_id, room_ids, from_key):
@@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
 
         defer.returnValue(user_ids_changed)
 
+    @defer.inlineCallbacks
+    def _incoming_device_list_update(self, origin, edu_content):
+        user_id = edu_content["user_id"]
+        device_id = edu_content["device_id"]
+        stream_id = edu_content["stream_id"]
+        prev_ids = edu_content.get("prev_id", [])
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            return
+
+        logger.info("Got edu: %r", edu_content)
+
+        with (yield self._remote_edue_linearizer.queue(user_id)):
+            resync = True
+            if len(prev_ids) == 1:
+                extremity = yield self.store.get_device_list_remote_extremity(user_id)
+                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+                if str(extremity) == str(prev_ids[0]):
+                    resync = False
+
+            if resync:
+                result = yield self.federation.query_user_devices(origin, user_id)
+                stream_id = result["stream_id"]
+                devices = result["devices"]
+                yield self.store.update_remote_device_list_cache(
+                    user_id, devices, stream_id,
+                )
+                device_ids = [device["device_id"] for device in devices]
+                yield self.notify_device_update(user_id, device_ids)
+            else:
+                content = dict(edu_content)
+                for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+                    content.pop(key, None)
+                yield self.store.update_remote_device_list_cache_entry(
+                    user_id, device_id, content, stream_id,
+                )
+                yield self.notify_device_update(user_id, [device_id])
+
+    @defer.inlineCallbacks
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})