summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2021-11-26 16:52:44 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2021-12-08 18:30:51 +0000
commitbf40bfe37fcbd7c72de0310376c1fe67696e6571 (patch)
tree31c3829d4ae0d0fe1c7395fa1df700250afda543
parentAdd docstring to add_device_change_to_streams and fix types. (diff)
downloadsynapse-bf40bfe37fcbd7c72de0310376c1fe67696e6571.tar.xz
wip
-rw-r--r--synapse/handlers/appservice.py59
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/devices.py46
3 files changed, 103 insertions, 4 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index c92668642a..9c987d9bb5 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -58,6 +58,7 @@ class ApplicationServicesHandler:
         self._msc2409_to_device_messages_enabled = (
             hs.config.experimental.msc2409_to_device_messages_enabled
         )
+        self._msc3202_enabled = hs.config.experimental.msc3202_enabled
 
         self.current_max = 0
         self.is_processing = False
@@ -204,9 +205,9 @@ class ApplicationServicesHandler:
         Args:
             stream_key: The stream the event came from.
 
-                `stream_key` can be "typing_key", "receipt_key", "presence_key" or
-                "to_device_key". Any other value for `stream_key` will cause this function
-                to return early.
+                `stream_key` can be "typing_key", "receipt_key", "presence_key",
+                "to_device_key" or "device_list_key". Any other value fo
+                `stream_key` will cause this function to return early.
 
                 Ephemeral events will only be pushed to appservices that have opted into
                 receiving them by setting `push_ephemeral` to true in their registration
@@ -230,6 +231,7 @@ class ApplicationServicesHandler:
             "receipt_key",
             "presence_key",
             "to_device_key",
+            "device_list_key",
         ):
             return
 
@@ -253,6 +255,10 @@ class ApplicationServicesHandler:
         ):
             return
 
+        # Ignore device lists if the feature flag is not enabled
+        if stream_key == "device_list_key" and not self._msc3202_enabled:
+            return
+
         # Check whether there are any appservices which have registered to receive
         # ephemeral events.
         #
@@ -336,6 +342,20 @@ class ApplicationServicesHandler:
                             service, "to_device", new_token
                         )
 
+                    elif stream_key == "device_list_key":
+                        events = await self._handle_device_list_updates(
+                            service, new_token, users
+                        )
+                        if events:
+                            self.scheduler.submit_ephemeral_events_for_as(
+                                service, events
+                            )
+
+                        # Persist the latest handled stream token for this appservice
+                        await self.store.set_type_stream_id_for_appservice(
+                            service, "device_list", new_token
+                        )
+
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
@@ -541,6 +561,39 @@ class ApplicationServicesHandler:
 
         return message_payload
 
+    async def _get_device_list_updates(
+        self,
+        service: ApplicationService,
+        new_token: int,
+        users: Collection[Union[UserID, str]],
+    ) -> List[JsonDict]:
+        """
+
+
+        Args:
+            service:
+            new_token:
+            users:
+
+        Returns:
+
+        """
+        users_appservice_is_interested_in = [
+            user for user in users if service.is_interested_in_user(user)
+        ]
+
+        if not users_appservice_is_interested_in:
+            # This appservice was not interested in any of these users.
+            return []
+
+        # Fetch the last successfully processed device list update stream ID
+        # for this appservice.
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "device_list"
+        )
+
+        # Fetch device lists updates for each user.
+
     async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 68ba330432..108e919564 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -424,7 +424,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence", "to_device"):
+        if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e78631352e..4e430ef480 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -581,6 +581,52 @@ class DeviceWorkerStore(SQLBaseStore):
             changes = set()
 
             sql = """
+                SELECT DISTINCT FROM device_lists_stream
+                WHERE stream_id > ?
+                AND
+            """
+
+            for chunk in batch_iter(to_check, 100):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "user_id", chunk
+                )
+                txn.execute(sql + clause, (from_key,) + tuple(args))
+                changes.update(user_id for user_id, in txn)
+
+            return changes
+
+        return await self.db_pool.runInteraction(
+            "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
+        )
+
+    async def get_all_device_list_changes_for_users(
+        self, user_ids: Iterable[str], from_key: int
+    ) -> List[JsonDict]:
+        """
+        Get a list of device updates for a collection of users between the
+        given stream ID and now.
+
+        Args:
+            user_ids: The user IDs to fetch device list updates for.
+            from_key: The minimum device list stream ID to fetch updates from, inclusive.
+
+        Returns:
+            The device list changes, ordered by ascending stream ID.
+            # TODO: Should return max_stream_id?
+        """
+        # Get set of users who *may* have changed. Users not in the returned
+        # list have definitely not changed.
+        to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
+        )
+
+        if not to_check:
+            return []
+
+        def _get_all_device_list_changes_for_users_txn(txn):
+            changes = set()
+
+            sql = """
                 SELECT DISTINCT user_id FROM device_lists_stream
                 WHERE stream_id > ?
                 AND