From c974116f197d211ba9b42159fe61cfd5957411b5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Jan 2017 16:06:54 +0000 Subject: Implement device key caching over federation --- synapse/handlers/device.py | 85 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 12 deletions(-) (limited to 'synapse/handlers/device.py') diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index ba4c48d590..2d66b3721a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,6 +15,7 @@ from synapse.api import errors from synapse.util import stringutils +from synapse.util.async import Linearizer from synapse.types import get_domain_from_id from twisted.internet import defer from ._base import BaseHandler @@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler): def __init__(self, hs): super(DeviceHandler, self).__init__(hs) + self.hs = hs self.state = hs.get_state_handler() - self.federation = hs.get_federation_sender() + self.federation_sender = hs.get_federation_sender() + self.federation = hs.get_replication_layer() + self._remote_edue_linearizer = Linearizer(name="remote_device_list") + + self.federation.register_edu_handler( + "m.device_list_update", self._incoming_device_list_update, + ) + self.federation.register_query_handler( + "user_devices", self.on_federation_query_user_devices, + ) @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, @@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) # if the device id is not specified, we'll autogen one, but loop a few @@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler): initial_device_display_name=initial_device_display_name, ) if new_device: - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) defer.returnValue(device_id) attempts += 1 @@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler): user_id=user_id, device_id=device_id ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): @@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler): device_id, new_display_name=content.get("display_name") ) - yield self.notify_device_update(user_id, device_id) + yield self.notify_device_update(user_id, [device_id]) except errors.StoreError, e: if e.code == 404: raise errors.NotFoundError() @@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler): raise @defer.inlineCallbacks - def notify_device_update(self, user_id, device_id): + def notify_device_update(self, user_id, device_ids): rooms = yield self.store.get_rooms_for_user(user_id) room_ids = [r.room_id for r in rooms] hosts = set() - for room_id in room_ids: - users = yield self.state.get_current_user_in_room(room_id) - hosts.update(get_domain_from_id(u) for u in users) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + users = yield self.state.get_current_user_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in users) + hosts.discard(self.server_name) position = yield self.store.add_device_change_to_streams( - user_id, device_id, list(hosts) + user_id, device_ids, list(hosts) ) yield self.notifier.on_new_event( "device_list_key", position, rooms=room_ids, ) + logger.info("Sending device list update notif to: %r", hosts) for host in hosts: - self.federation.send_device_messages(host) + self.federation_sender.send_device_messages(host) @defer.inlineCallbacks def get_device_list_changes(self, user_id, room_ids, from_key): @@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler): defer.returnValue(user_ids_changed) + @defer.inlineCallbacks + def _incoming_device_list_update(self, origin, edu_content): + user_id = edu_content["user_id"] + device_id = edu_content["device_id"] + stream_id = edu_content["stream_id"] + prev_ids = edu_content.get("prev_id", []) + + if get_domain_from_id(user_id) != origin: + # TODO: Raise? + return + + logger.info("Got edu: %r", edu_content) + + with (yield self._remote_edue_linearizer.queue(user_id)): + resync = True + if len(prev_ids) == 1: + extremity = yield self.store.get_device_list_remote_extremity(user_id) + logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids) + if str(extremity) == str(prev_ids[0]): + resync = False + + if resync: + result = yield self.federation.query_user_devices(origin, user_id) + stream_id = result["stream_id"] + devices = result["devices"] + yield self.store.update_remote_device_list_cache( + user_id, devices, stream_id, + ) + device_ids = [device["device_id"] for device in devices] + yield self.notify_device_update(user_id, device_ids) + else: + content = dict(edu_content) + for key in ("user_id", "device_id", "stream_id", "prev_ids"): + content.pop(key, None) + yield self.store.update_remote_device_list_cache_entry( + user_id, device_id, content, stream_id, + ) + yield self.notify_device_update(user_id, [device_id]) + + @defer.inlineCallbacks + def on_federation_query_user_devices(self, user_id): + stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) + defer.returnValue({ + "user_id": user_id, + "stream_id": stream_id, + "devices": devices, + }) + def _update_device_from_client_ips(device, client_ips): ip = client_ips.get((device["user_id"], device["device_id"]), {}) -- cgit 1.4.1