Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/data_stores/main/__init__.py    5
-rw-r--r--  synapse/storage/data_stores/main/devices.py   132
2 files changed, 72 insertions(+), 65 deletions(-)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
             db_conn,
             "device_lists_stream",
             "stream_id",
-            extra_tables=[("user_signature_stream", "stream_id")],
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         self._cross_signing_id_gen = StreamIdGenerator(
             db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index d55733a4cd..4c19c02bbc 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple
 
 from six import iteritems
 
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import Collection, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,
@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []
 
-        # We retrieve n+1 devices from the list of outbound pokes where n is
-        # our outbound device update limit. We then check if the very last
-        # device has the same stream_id as the second-to-last device. If so,
-        # then we ignore all devices with that stream_id and only send the
-        # devices with a lower stream_id.
-        #
-        # If when culling the list we end up with no devices afterwards, we
-        # consider the device update to be too large, and simply skip the
-        # stream_id; the rationale being that such a large device list update
-        # is likely an error.
         updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
-            limit + 1,
+            limit,
         )
 
         # Return an empty list if there are no updates
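Why the `limit + 1` overfetch (and the deleted culling comment) could go: each
device change now gets its own stream ID (see the `get_next_mult` changes
further down), so no two rows in a page share a stream position and a plain
LIMIT can no longer split one position across two pages. A hedged sketch of the
query shape this implies (column list abbreviated; not the verbatim
`_get_device_updates_by_remote_txn`):

    def _get_device_updates_by_remote_txn(
        txn, destination, from_stream_id, now_stream_id, limit
    ):
        # Every row carries a unique stream_id, so paging by LIMIT is safe.
        sql = """
            SELECT user_id, device_id, stream_id, opentracing_context
            FROM device_lists_outbound_pokes
            WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
            ORDER BY stream_id
            LIMIT ?
        """
        txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
        return list(txn)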
@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
 
-        # if we have exceeded the limit, we need to exclude any results with the
-        # same stream_id as the last row.
-        if len(updates) > limit:
-            stream_id_cutoff = updates[-1][2]
-            now_stream_id = stream_id_cutoff - 1
-        else:
-            stream_id_cutoff = None
-
         # Perform the equivalent of a GROUP BY
         #
         # Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
         query_map = {}
         cross_signing_keys_by_user = {}
         for user_id, device_id, update_stream_id, update_context in updates:
-            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
-                # Stop processing updates
-                break
-
             if (
                 user_id in master_key_by_user
                 and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
                 if update_stream_id > previous_update_stream_id:
                     query_map[key] = (update_stream_id, update_context)
 
-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
-
-        # That should only happen if a client is spamming the server with new
-        # devices, in which case E2E isn't going to work well anyway. We'll just
-        # skip that stream_id and return an empty list, and continue with the next
-        # stream_id next time.
-        if not query_map and not cross_signing_keys_by_user:
-            return stream_id_cutoff, []
-
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )
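With the cutoff logic gone, the "GROUP BY equivalent" that survives is simply:
for each `(user_id, device_id)` key, keep the update with the highest stream
ID. Condensed from the loop above (cross-signing key handling elided):

    query_map = {}
    for user_id, device_id, update_stream_id, update_context in updates:
        key = (user_id, device_id)
        previous_update_stream_id, _ = query_map.get(key, (0, None))
        # A later update supersedes any earlier one for the same device.
        if update_stream_id > previous_update_stream_id:
            query_map[key] = (update_stream_id, update_context)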
@@ -607,21 +575,26 @@ class DeviceWorkerStore(SQLBaseStore):
         else:
             return set()
 
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
         """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
+
+        # This query applies the stream ID bounds to the outer SELECT, which
+        # the database correctly pushes down into both UNION ALL branches.
         sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            SELECT stream_id, entity FROM (
+                SELECT stream_id, user_id AS entity FROM device_lists_stream
+                UNION ALL
+                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+            ) AS e
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
         """
-        return self.db.execute(
+
+        return await self.db.execute(
             "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
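Callers can tell the two kinds of entity apart by the sigil: Matrix user IDs
always start with '@', while federation destinations are bare server names. A
usage sketch under that assumption (both handlers are hypothetical stand-ins
for the real consumers):

    async def process_device_changes(store, from_key: int, to_key: int):
        rows = await store.get_all_device_list_changes_for_remotes(from_key, to_key)
        for stream_id, entity in rows:
            if entity.startswith("@"):
                handle_user_change(stream_id, entity)  # hypothetical handler
            else:
                handle_destination_poke(stream_id, entity)  # hypothetical handler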
 
@@ -1017,29 +990,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
-        with self._device_list_id_gen.get_next() as stream_id:
+        if not device_ids:
+            return
+
+        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
             yield self.db.runInteraction(
-                "add_device_change_to_streams",
-                self._add_device_change_txn,
+                "add_device_change_to_stream",
+                self._add_device_change_to_stream_txn,
+                user_id,
+                device_ids,
+                stream_ids,
+            )
+
+        if not hosts:
+            return stream_ids[-1]
+
+        context = get_active_span_text_map()
+        with self._device_list_id_gen.get_next_mult(
+            len(hosts) * len(device_ids)
+        ) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_outbound_poke_to_stream",
+                self._add_device_outbound_poke_to_stream_txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_id,
+                stream_ids,
+                context,
             )
-        return stream_id
 
-    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
-        now = self._clock.time_msec()
+        return stream_ids[-1]
 
+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[int],
+    ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_id,
-            )
+
+        min_stream_id = stream_ids[0]
 
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
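The rewrite splits `add_device_changed_to_streams` into two transactions: one
batch of `len(device_ids)` IDs for `device_lists_stream`, then a second batch
of `len(hosts) * len(device_ids)` IDs for the outbound pokes, with the last ID
returned as the new stream position. A toy model of the `get_next_mult`
contract this depends on (illustrative only; the real generator also tracks
in-flight IDs across concurrent writers):

    from contextlib import contextmanager

    class _Counter:
        def __init__(self, start=0):
            self.current = start

    @contextmanager
    def get_next_mult(id_gen, n):
        # Hand out n consecutive, ascending IDs; the caller treats the last
        # one as the stream position once its writes have landed.
        first = id_gen.current + 1
        id_gen.current += n
        yield list(range(first, first + n))

    id_gen = _Counter()
    with get_next_mult(id_gen, 3) as stream_ids:
        assert stream_ids == [1, 2, 3]  # one ID per device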
@@ -1048,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )
 
         self.db.simple_insert_many_txn(
@@ -1056,11 +1049,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_stream",
             values=[
                 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
-                for device_id in device_ids
+                for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
 
-        context = get_active_span_text_map()
+    def _add_device_outbound_poke_to_stream_txn(
+        self, txn, user_id, device_ids, hosts, stream_ids, context,
+    ):
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
+
+        now = self._clock.time_msec()
+        next_stream_id = iter(stream_ids)
 
         self.db.simple_insert_many_txn(
             txn,
@@ -1068,7 +1072,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 {
                     "destination": destination,
-                    "stream_id": stream_id,
+                    "stream_id": next(next_stream_id),
                     "user_id": user_id,
                     "device_id": device_id,
                     "sent": False,