diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
- extra_tables=[("user_signature_stream", "stream_id")],
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index d55733a4cd..4c19c02bbc 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Dict, List, Tuple
from six import iteritems
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
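The comments deleted above described the previous strategy: fetch `limit + 1` rows and, if the limit was exceeded, drop every row sharing the final stream_id, or skip that stream_id entirely if nothing survived the cull. The per-device stream ids introduced further down appear to be what makes a plain limit sufficient now. For context, a sketch of the removed culling step, reconstructed from the deleted comments and branches (assuming `updates` is sorted ascending by its third field, the stream_id):

    # Reconstruction of the removed cull step, for context only.
    def cull_updates(updates, limit, now_stream_id):
        if len(updates) > limit:
            # The (limit+1)-th row identifies a stream_id that may be
            # truncated: drop every row sharing it and report a stream
            # position just below it.
            stream_id_cutoff = updates[-1][2]
            updates = [u for u in updates if u[2] < stream_id_cutoff]
            now_stream_id = stream_id_cutoff - 1
        return updates, now_stream_id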
@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -607,21 +575,26 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # The outer WHERE clause constrains both branches of the UNION, so the
+ # stream_id bounds are correctly applied to each inner SELECT.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
"""
- return self.db.execute(
+
+ return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -1017,29 +990,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[int],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
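The first transaction relies on `get_next_mult` handing back one id per device, in ascending order: `zip()` pairs ids and devices positionally, and `stream_ids[0]` then serves as a safe lower bound for pruning older rows. A runnable sketch with a stand-in generator (`FakeIdGen` is illustrative, not Synapse's implementation):

    from contextlib import contextmanager

    class FakeIdGen:
        """Stand-in for StreamIdGenerator.get_next_mult (sketch only)."""

        def __init__(self, start: int = 100):
            self._current = start

        @contextmanager
        def get_next_mult(self, n: int):
            # Hand out n consecutive, ascending ids.
            ids = list(range(self._current + 1, self._current + n + 1))
            self._current += n
            yield ids

    id_gen = FakeIdGen()
    device_ids = ["DEV1", "DEV2", "DEV3"]

    with id_gen.get_next_mult(len(device_ids)) as stream_ids:
        rows = [
            {"stream_id": sid, "user_id": "@alice:example.com", "device_id": dev}
            for sid, dev in zip(stream_ids, device_ids)
        ]
        min_stream_id = stream_ids[0]  # older rows below this can be deleted

    assert [r["stream_id"] for r in rows] == [101, 102, 103]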
@@ -1048,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1056,11 +1049,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Collection[str],
+ stream_ids: List[int],
+ context: Dict[str, str],
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1068,7 +1072,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,