summary refs log tree commit diff
path: root/synapse/replication/slave/storage/account_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/account_data.py')
-rw-r--r--synapse/replication/slave/storage/account_data.py41
1 files changed, 40 insertions, 1 deletions
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index f59b0eabbc..735c03c7eb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,7 +15,10 @@
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage import DataStore
 from synapse.storage.account_data import AccountDataStore
+from synapse.storage.tags import TagsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedAccountDataStore(BaseSlavedStore):
@@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id",
         )
+        self._account_data_stream_cache = StreamChangeCache(
+            "AccountDataAndTagsChangeCache",
+            self._account_data_id_gen.get_current_token(),
+        )
+
+    get_account_data_for_user = (
+        AccountDataStore.__dict__["get_account_data_for_user"]
+    )
 
     get_global_account_data_by_type_for_users = (
         AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
         AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
     )
 
+    get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
+
+    get_updated_tags = DataStore.get_updated_tags.__func__
+    get_updated_account_data_for_user = (
+        DataStore.get_updated_account_data_for_user.__func__
+    )
+
+    def get_max_account_data_stream_id(self):
+        return self._account_data_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedAccountDataStore, self).stream_positions()
         position = self._account_data_id_gen.get_current_token()
@@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                user_id, data_type = row[1:3]
+                position, user_id, data_type = row[:3]
                 self.get_global_account_data_by_type_for_user.invalidate(
                     (data_type, user_id,)
                 )
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )
 
         stream = result.get("room_account_data")
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.get_account_data_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )
 
         stream = result.get("tag_account_data")
         if stream:
             self._account_data_id_gen.advance(int(stream["position"]))
+            for row in stream["rows"]:
+                position, user_id = row[:2]
+                self.get_tags_for_user.invalidate((user_id,))
+                self._account_data_stream_cache.entity_has_changed(
+                    user_id, position
+                )
+
+        return super(SlavedAccountDataStore, self).process_replication(result)