summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/device.py141
-rw-r--r--synapse/handlers/e2e_keys.py43
-rw-r--r--synapse/handlers/sync.py28
3 files changed, 195 insertions, 17 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index aa68755936..6fefb85890 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -15,6 +15,8 @@
 
 from synapse.api import errors
 from synapse.util import stringutils
+from synapse.util.async import Linearizer
+from synapse.types import get_domain_from_id
 from twisted.internet import defer
 from ._base import BaseHandler
 
@@ -27,6 +29,21 @@ class DeviceHandler(BaseHandler):
     def __init__(self, hs):
         super(DeviceHandler, self).__init__(hs)
 
+        self.hs = hs
+        self.state = hs.get_state_handler()
+        self.federation_sender = hs.get_federation_sender()
+        self.federation = hs.get_replication_layer()
+        self._remote_edue_linearizer = Linearizer(name="remote_device_list")
+
+        self.federation.register_edu_handler(
+            "m.device_list_update", self._incoming_device_list_update,
+        )
+        self.federation.register_query_handler(
+            "user_devices", self.on_federation_query_user_devices,
+        )
+
+        hs.get_distributor().observe("user_left_room", self.user_left_room)
+
     @defer.inlineCallbacks
     def check_device_registered(self, user_id, device_id,
                                 initial_device_display_name=None):
@@ -45,29 +62,29 @@ class DeviceHandler(BaseHandler):
             str: device id (generated if none was supplied)
         """
         if device_id is not None:
-            yield self.store.store_device(
+            new_device = yield self.store.store_device(
                 user_id=user_id,
                 device_id=device_id,
                 initial_device_display_name=initial_device_display_name,
-                ignore_if_known=True,
             )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
             defer.returnValue(device_id)
 
         # if the device id is not specified, we'll autogen one, but loop a few
         # times in case of a clash.
         attempts = 0
         while attempts < 5:
-            try:
-                device_id = stringutils.random_string(10).upper()
-                yield self.store.store_device(
-                    user_id=user_id,
-                    device_id=device_id,
-                    initial_device_display_name=initial_device_display_name,
-                    ignore_if_known=False,
-                )
+            device_id = stringutils.random_string(10).upper()
+            new_device = yield self.store.store_device(
+                user_id=user_id,
+                device_id=device_id,
+                initial_device_display_name=initial_device_display_name,
+            )
+            if new_device:
+                yield self.notify_device_update(user_id, [device_id])
                 defer.returnValue(device_id)
-            except errors.StoreError:
-                attempts += 1
+            attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
 
@@ -147,6 +164,8 @@ class DeviceHandler(BaseHandler):
             user_id=user_id, device_id=device_id
         )
 
+        yield self.notify_device_update(user_id, [device_id])
+
     @defer.inlineCallbacks
     def update_device(self, user_id, device_id, content):
         """ Update the given device
@@ -166,12 +185,110 @@ class DeviceHandler(BaseHandler):
                 device_id,
                 new_display_name=content.get("display_name")
             )
+            yield self.notify_device_update(user_id, [device_id])
         except errors.StoreError, e:
             if e.code == 404:
                 raise errors.NotFoundError()
             else:
                 raise
 
+    @defer.inlineCallbacks
+    def notify_device_update(self, user_id, device_ids):
+        """Notify that a user's device(s) has changed. Pokes the notifier, and
+        remote servers if the user is local.
+        """
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        room_ids = [r.room_id for r in rooms]
+
+        hosts = set()
+        if self.hs.is_mine_id(user_id):
+            for room_id in room_ids:
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts.update(get_domain_from_id(u) for u in users)
+            hosts.discard(self.server_name)
+
+        position = yield self.store.add_device_change_to_streams(
+            user_id, device_ids, list(hosts)
+        )
+
+        yield self.notifier.on_new_event(
+            "device_list_key", position, rooms=room_ids,
+        )
+
+        if hosts:
+            logger.info("Sending device list update notif to: %r", hosts)
+            for host in hosts:
+                self.federation_sender.send_device_messages(host)
+
+    @defer.inlineCallbacks
+    def _incoming_device_list_update(self, origin, edu_content):
+        user_id = edu_content["user_id"]
+        device_id = edu_content["device_id"]
+        stream_id = edu_content["stream_id"]
+        prev_ids = edu_content.get("prev_id", [])
+
+        if get_domain_from_id(user_id) != origin:
+            # TODO: Raise?
+            logger.warning("Got device list update edu for %r from %r", user_id, origin)
+            return
+
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We don't share any rooms with this user. Ignore update, as we
+            # probably won't get any further updates.
+            return
+
+        with (yield self._remote_edue_linearizer.queue(user_id)):
+            # If the prev id matches whats in our cache table, then we don't need
+            # to resync the users device list, otherwise we do.
+            resync = True
+            if len(prev_ids) == 1:
+                extremity = yield self.store.get_device_list_last_stream_id_for_remote(
+                    user_id
+                )
+                logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
+                if str(extremity) == str(prev_ids[0]):
+                    resync = False
+
+            if resync:
+                # Fetch all devices for the user.
+                result = yield self.federation.query_user_devices(origin, user_id)
+                stream_id = result["stream_id"]
+                devices = result["devices"]
+                yield self.store.update_remote_device_list_cache(
+                    user_id, devices, stream_id,
+                )
+                device_ids = [device["device_id"] for device in devices]
+                yield self.notify_device_update(user_id, device_ids)
+            else:
+                # Simply update the single device, since we know that is the only
+                # change (becuase of the single prev_id matching the current cache)
+                content = dict(edu_content)
+                for key in ("user_id", "device_id", "stream_id", "prev_ids"):
+                    content.pop(key, None)
+                yield self.store.update_remote_device_list_cache_entry(
+                    user_id, device_id, content, stream_id,
+                )
+                yield self.notify_device_update(user_id, [device_id])
+
+    @defer.inlineCallbacks
+    def on_federation_query_user_devices(self, user_id):
+        stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+        defer.returnValue({
+            "user_id": user_id,
+            "stream_id": stream_id,
+            "devices": devices,
+        })
+
+    @defer.inlineCallbacks
+    def user_left_room(self, user, room_id):
+        user_id = user.to_string()
+        rooms = yield self.store.get_rooms_for_user(user_id)
+        if not rooms:
+            # We no longer share rooms with this user, so we'll no longer
+            # receive device updates. Mark this in DB.
+            yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
 
 def _update_device_from_client_ips(device, client_ips):
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index b63a660c06..a16b9def8d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
             if self.is_mine_id(user_id):
                 local_query[user_id] = device_ids
             else:
-                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_ids
+                remote_queries[user_id] = device_ids
 
-        # do the queries
+        # Firt get local devices.
         failures = {}
         results = {}
         if local_query:
@@ -85,9 +84,42 @@ class E2eKeysHandler(object):
                 if user_id in local_query:
                     results[user_id] = keys
 
+        # Now attempt to get any remote devices from our local cache.
+        remote_queries_not_in_cache = {}
+        if remote_queries:
+            query_list = []
+            for user_id, device_ids in remote_queries.iteritems():
+                if device_ids:
+                    query_list.extend((user_id, device_id) for device_id in device_ids)
+                else:
+                    query_list.append((user_id, None))
+
+            user_ids_not_in_cache, remote_results = (
+                yield self.store.get_user_devices_from_cache(
+                    query_list
+                )
+            )
+            for user_id, devices in remote_results.iteritems():
+                user_devices = results.setdefault(user_id, {})
+                for device_id, device in devices.iteritems():
+                    keys = device.get("keys", None)
+                    device_display_name = device.get("device_display_name", None)
+                    if keys:
+                        result = dict(keys)
+                        unsigned = result.setdefault("unsigned", {})
+                        if device_display_name:
+                            unsigned["device_display_name"] = device_display_name
+                        user_devices[device_id] = result
+
+            for user_id in user_ids_not_in_cache:
+                domain = get_domain_from_id(user_id)
+                r = remote_queries_not_in_cache.setdefault(domain, {})
+                r[user_id] = remote_queries[user_id]
+
+        # Now fetch any devices that we don't have in our cache
         @defer.inlineCallbacks
         def do_remote_query(destination):
-            destination_query = remote_queries[destination]
+            destination_query = remote_queries_not_in_cache[destination]
             try:
                 limiter = yield get_retry_limiter(
                     destination, self.clock, self.store
@@ -119,7 +151,7 @@ class E2eKeysHandler(object):
 
         yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
-            for destination in remote_queries
+            for destination in remote_queries_not_in_cache
         ]))
 
         defer.returnValue({
@@ -259,6 +291,7 @@ class E2eKeysHandler(object):
                 user_id, device_id, time_now,
                 encode_canonical_json(device_keys)
             )
+            yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c880f61685..9199f20817 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
     "to_device",  # List of direct messages for the device.
+    "device_lists",  # List of user_ids whose devices have chanegd
 ])):
     __slots__ = []
 
@@ -544,6 +545,10 @@ class SyncHandler(object):
 
         yield self._generate_sync_entry_for_to_device(sync_result_builder)
 
+        device_lists = yield self._generate_sync_entry_for_device_list(
+            sync_result_builder
+        )
+
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -551,10 +556,33 @@ class SyncHandler(object):
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
+            device_lists=device_lists,
             next_batch=sync_result_builder.now_token,
         ))
 
     @defer.inlineCallbacks
+    def _generate_sync_entry_for_device_list(self, sync_result_builder):
+        user_id = sync_result_builder.sync_config.user.to_string()
+        since_token = sync_result_builder.since_token
+
+        if since_token and since_token.device_list_key:
+            rooms = yield self.store.get_rooms_for_user(user_id)
+            room_ids = set(r.room_id for r in rooms)
+
+            user_ids_changed = set()
+            changed = yield self.store.get_user_whose_devices_changed(
+                since_token.device_list_key
+            )
+            for other_user_id in changed:
+                other_rooms = yield self.store.get_rooms_for_user(other_user_id)
+                if room_ids.intersection(e.room_id for e in other_rooms):
+                    user_ids_changed.add(other_user_id)
+
+            defer.returnValue(user_ids_changed)
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     def _generate_sync_entry_for_to_device(self, sync_result_builder):
         """Generates the portion of the sync response. Populates
         `sync_result_builder` with the result.