diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 5eeca6165d..40187496e2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,24 +164,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=user_signature_stream_prefill,
)
- self._device_list_federation_stream_cache = None
- if hs.should_send_federation():
- (
- device_list_federation_prefill,
- device_list_federation_list_id,
- ) = self.db_pool.get_cache_dict(
- db_conn,
- "device_lists_outbound_pokes",
- entity_column="destination",
- stream_column="stream_id",
- max_value=device_list_max,
- limit=10000,
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache",
- device_list_federation_list_id,
- prefilled_cache=device_list_federation_prefill,
- )
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
@@ -209,29 +207,22 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
for row in rows:
if row.is_signature:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+ self._user_signature_stream_cache.entity_has_changed(row.entity, token)
continue
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- if not row.hosts_calculated:
- self._device_list_stream_cache.entity_has_changed(row.user_id, token)
- self.get_cached_devices_for_user.invalidate((row.user_id,))
- self._get_cached_user_device.invalidate((row.user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate(
- (row.user_id,)
- )
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- def device_lists_outbound_pokes_have_changed(
- self, destinations: StrCollection, token: int
- ) -> None:
- assert self._device_list_federation_stream_cache is not None
-
- for destination in destinations:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
@@ -372,11 +363,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
- if from_stream_id == now_stream_id:
- return now_stream_id, []
-
- if self._device_list_federation_stream_cache is None:
- raise Exception("Func can only be used on federation senders")
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -1032,10 +1018,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
- SELECT stream_id, user_id, hosts FROM (
- SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
- SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -1591,14 +1577,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
get_device_list_changes_in_room_txn,
)
- async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
- return await self.db_pool.simple_select_onecol(
- table="device_lists_outbound_pokes",
- keyvalues={"stream_id": stream_id},
- retcol="destination",
- desc="get_destinations_for_device",
- )
-
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -2134,13 +2112,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
- if self._device_list_federation_stream_cache:
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_ids[-1],
- )
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)
|