Diffstat (limited to 'synapse/storage/databases/main/devices.py'):

 synapse/storage/databases/main/devices.py | 81 +++++++++++++++++++++++-----
 1 file changed, 71 insertions(+), 10 deletions(-)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 979dd4e17e..aa58c2adc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import abc
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -39,6 +38,8 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -49,6 +50,11 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -80,9 +86,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        if hs.config.worker.worker_app is None:
+            self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn,
+                "device_lists_stream",
+                "stream_id",
+                extra_tables=[
+                    ("user_signature_stream", "stream_id"),
+                    ("device_lists_outbound_pokes", "stream_id"),
+                    ("device_lists_changes_in_room", "stream_id"),
+                ],
+            )
+        else:
+            self._device_list_id_gen = SlavedIdTracker(
+                db_conn,
+                "device_lists_stream",
+                "stream_id",
+                extra_tables=[
+                    ("user_signature_stream", "stream_id"),
+                    ("device_lists_outbound_pokes", "stream_id"),
+                    ("device_lists_changes_in_room", "stream_id"),
+                ],
+            )
+
         # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
         # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
-        device_list_max = self._device_list_id_gen.get_current_token()  # type: ignore[attr-defined]
+        device_list_max = self._device_list_id_gen.get_current_token()
         device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
             db_conn,
             "device_lists_stream",
@@ -136,6 +165,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
             )
 
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+            self._invalidate_caches_for_devices(token, rows)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+            for row in rows:
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
+
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
+
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
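
The new process_replication_rows hook advances the local token and then invalidates per-entity caches. Entities starting with "@" are user IDs whose device lists changed; anything else is a remote server we need to tell about changes. A hedged sketch of just that dispatch, with stand-in row types rather than Synapse's real replication classes:

from dataclasses import dataclass
from typing import Iterable, List, Tuple


@dataclass
class DeviceListsRow:
    entity: str  # "@user:server" for a user, "server.name" for a remote host


def classify_rows(rows: Iterable[DeviceListsRow]) -> Tuple[List[str], List[str]]:
    """Split rows into changed users vs hosts to notify, mirroring the
    startswith("@") check in _invalidate_caches_for_devices."""
    users: List[str] = []
    hosts: List[str] = []
    for row in rows:
        (users if row.entity.startswith("@") else hosts).append(row.entity)
    return users, hosts


assert classify_rows(
    [DeviceListsRow("@alice:example.org"), DeviceListsRow("matrix.org")]
) == (["@alice:example.org"], ["matrix.org"])
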
@@ -677,11 +739,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             },
         )
 
-    @abc.abstractmethod
-    def get_device_stream_token(self) -> int:
-        """Get the current stream id from the _device_list_id_gen"""
-        ...
-
     @trace
     @cancellable
     async def get_user_devices_from_cache(
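
The deletion above removes the @abc.abstractmethod placeholder: now that DeviceWorkerStore constructs its own id tracker, get_device_stream_token becomes a concrete one-liner (see the earlier hunk). A toy before/after of that refactor, with illustrative names only:

class BeforeWorkerStore:
    # the id generator lived on a sibling class, so the worker store could
    # only declare an abstract hook for concrete stores to fill in
    def get_device_stream_token(self) -> int:
        raise NotImplementedError


class AfterWorkerStore:
    def __init__(self) -> None:
        self._current_token = 0  # stands in for the real id tracker

    def get_device_stream_token(self) -> int:
        # owning the tracker turns the abstract hook into a concrete method
        return self._current_token
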
@@ -1481,6 +1538,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+    # Because we have write access, this will be a StreamIdGenerator
+    # (see DeviceWorkerStore.__init__)
+    _device_list_id_gen: AbstractStreamIdGenerator
+
     def __init__(
         self,
         database: DatabasePool,
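
The class-level annotation added to DeviceStore narrows the attribute's declared type: only the main process constructs a DeviceStore, so _device_list_id_gen is known to be the writer-capable variant, and mypy accepts calls to get_next()/get_next_mult() without the per-call-site ignores removed below. A minimal sketch of that narrowing pattern, assuming toy Tracker/Generator classes:

class Tracker:
    def get_current_token(self) -> int:
        return 0


class Generator(Tracker):
    def get_next(self) -> int:
        return 1


class WorkerStore:
    _id_gen: Tracker  # may be a read-only tracker on workers

    def __init__(self, is_writer: bool) -> None:
        self._id_gen = Generator() if is_writer else Tracker()


class WriterStore(WorkerStore):
    # writers always construct a Generator, so the subclass re-declares the
    # attribute with the narrower type; mypy then accepts get_next() below
    _id_gen: Generator

    def allocate(self) -> int:
        return self._id_gen.get_next()
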
@@ -1805,7 +1866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 context,
             )
 
-        async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -2044,7 +2105,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 [],
             )
 
-        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
             return await self.db_pool.runInteraction(
                 "add_device_list_outbound_pokes",
                 add_device_list_outbound_pokes_txn,
@@ -2058,7 +2119,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         updates during partial joins.
         """
 
-        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.simple_upsert(
                 table="device_lists_remote_pending",
                 keyvalues={
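
Throughout the hunks above, new stream ids are taken inside "async with ... as stream_ids", so the advertised position only moves once the enclosed write completes. A simplified, self-contained sketch of that contract (the real StreamIdGenerator also handles multiple concurrent writers and persisted positions; the names here are toys):

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, List


class ToyStreamIdGen:
    def __init__(self) -> None:
        self._allocated = 0
        self._current = 0

    @asynccontextmanager
    async def get_next_mult(self, n: int) -> AsyncIterator[List[int]]:
        ids = [self._allocated + i + 1 for i in range(n)]
        self._allocated += n
        try:
            yield ids  # caller persists rows under these stream ids
        finally:
            # readers only observe the new position after the write finishes
            self._current = max(self._current, ids[-1])

    def get_current_token(self) -> int:
        return self._current


async def main() -> None:
    gen = ToyStreamIdGen()
    async with gen.get_next_mult(3) as stream_ids:
        assert stream_ids == [1, 2, 3]
        assert gen.get_current_token() == 0  # not yet visible to readers
    assert gen.get_current_token() == 3      # visible after the block exits


asyncio.run(main())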