diff --git a/changelog.d/7010.misc b/changelog.d/7010.misc
new file mode 100644
index 0000000000..4ba1f6cdf8
--- /dev/null
+++ b/changelog.d/7010.misc
@@ -0,0 +1 @@
+Change device list streams to have one row per ID.
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b2c764bfe8..cdc078cf11 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -676,8 +676,9 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
- room_ids = await self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
+ if row.entity.startswith("@"):
+ room_ids = await self.store.get_rooms_for_user(row.entity)
+ all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
await self.presence_handler.process_replication_rows(token, rows)
@@ -774,7 +775,10 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
- hosts = {row.destination for row in rows}
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
- for row in rows:
- self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
- def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(user_id, token)
-
- if destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ def _invalidate_caches_for_devices(self, token, rows):
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate_many((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- self.get_cached_devices_for_user.invalidate((user_id,))
- self._get_cached_user_device.invalidate_many((user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..7a8b6e9df1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -94,9 +94,13 @@ PublicRoomsStreamRow = namedtuple(
"network_id", # str, optional
),
)
-DeviceListsStreamRow = namedtuple(
- "DeviceListsStreamRow", ("user_id", "destination") # str # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+ entity = attr.ib(type=str)
+
+
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
@@ -363,7 +367,8 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
- """Someone added/changed/removed a device
+ """Either a user has updated their devices or a remote server needs to be
+ told about a device update.
"""
NAME = "device_lists"
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
- extra_tables=[("user_signature_stream", "stream_id")],
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index d55733a4cd..4c19c02bbc 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -607,21 +575,26 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
"""
- return self.db.execute(
+
+ return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -1017,29 +990,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[str],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1048,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1056,11 +1049,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1068,7 +1072,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6f8d990959..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- @defer.inlineCallbacks
- def test_get_device_updates_by_remote_limited(self):
- # Test breaking the update limit in 1, 101, and 1 device_id segments
-
- # first add one device
- device_ids1 = ["device_id0"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids1, ["someotherhost"]
- )
-
- # then add 101
- device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids2, ["someotherhost"]
- )
-
- # then one more
- device_ids3 = ["newdevice"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids3, ["someotherhost"]
- )
-
- #
- # now read them back.
- #
-
- # first we should get a single update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", -1, limit=100
- )
- self._check_devices_in_updates(device_ids1, device_updates)
-
- # Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self.assertEqual(len(device_updates), 0)
-
- # The 101 devices should've been cleared, so we should now just get one device
- # update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self._check_devices_in_updates(device_ids3, device_updates)
-
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
|