summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/device.py39
-rw-r--r--synapse/rest/client/v2_alpha/keys.py2
-rw-r--r--synapse/storage/stream.py7
3 files changed, 43 insertions, 5 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4a28d95967..4589dab409 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.api import errors
+from synapse.api.constants import EventTypes
 from synapse.util import stringutils
 from synapse.util.async import Linearizer
 from synapse.types import get_domain_from_id
@@ -221,15 +222,45 @@ class DeviceHandler(BaseHandler):
                 self.federation_sender.send_device_messages(host)
 
     @defer.inlineCallbacks
-    def get_user_ids_changed(self, user_id, from_device_key):
+    def get_user_ids_changed(self, user_id, from_token):
         rooms = yield self.store.get_rooms_for_user(user_id)
         room_ids = set(r.room_id for r in rooms)
 
-        user_ids_changed = set()
+        # First we check if any devices have changed
         changed = yield self.store.get_user_whose_devices_changed(
-            from_device_key
+            from_token.device_list_key
         )
-        for other_user_id in changed:
+
+        # Then work out if any users have since joined
+        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
+
+        possibly_changed = set(changed)
+        for room_id in rooms_changed:
+            # Fetch (an approximation) of the current state at the time.
+            event_rows, token = yield self.store.get_recent_event_ids_for_room(
+                room_id, end_token=from_token.room_key, limit=1,
+            )
+
+            if event_rows:
+                last_event_id = event_rows[-1]["event_id"]
+                prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id)
+            else:
+                prev_state_ids = {}
+
+            current_state_ids = yield self.state.get_current_state_ids(room_id)
+
+            # If there has been any change in membership, include them in the
+            # possibly changed list. We'll check if they are joined below,
+            # and we're not toooo worried about spuriously adding users.
+            for key, event_id in current_state_ids.iteritems():
+                etype, state_key = key
+                if etype == EventTypes.Member:
+                    prev_event_id = prev_state_ids.get(key, None)
+                    if not prev_event_id or prev_event_id != event_id:
+                        possibly_changed.add(state_key)
+
+        user_ids_changed = set()
+        for other_user_id in possibly_changed:
             other_rooms = yield self.store.get_rooms_for_user(other_user_id)
             if room_ids.intersection(e.room_id for e in other_rooms):
                 user_ids_changed.add(other_user_id)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 4590efa6bf..f99b53530a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -189,7 +189,7 @@ class KeyChangesServlet(RestServlet):
         user_id = requester.user.to_string()
 
         changed = yield self.device_handler.get_user_ids_changed(
-            user_id, from_token.device_list_key,
+            user_id, from_token,
         )
 
         defer.returnValue((200, {
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2dc24951c4..cdc1838895 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -244,6 +244,13 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def get_rooms_that_changed(self, room_ids, from_key):
+        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        return set(
+            room_id for room_id in room_ids
+            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+        )
+
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):