summary refs log tree commit diff
path: root/synapse/replication/slave/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/receipts.py')
-rw-r--r--synapse/replication/slave/storage/receipts.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ec007516d0..ac9662d399 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 from synapse.storage import DataStore
 from synapse.storage.receipts import ReceiptsStore
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
@@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
+        self._receipts_stream_cache = StreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+        )
+
     get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
+    get_linearized_receipts_for_room = (
+        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
+    )
+    _get_linearized_receipts_for_rooms = (
+        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
+    )
+    get_last_receipt_event_id_for_user = (
+        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
+    )
 
     get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
     get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
 
+    get_linearized_receipts_for_rooms = (
+        DataStore.get_linearized_receipts_for_rooms.__func__
+    )
+
     def stream_positions(self):
         result = super(SlavedReceiptsStore, self).stream_positions()
         result["receipts"] = self._receipts_id_gen.get_current_token()
@@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
         if stream:
             self._receipts_id_gen.advance(int(stream["position"]))
             for row in stream["rows"]:
-                room_id, receipt_type, user_id = row[1:4]
+                position, room_id, receipt_type, user_id = row[:4]
                 self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
+                self._receipts_stream_cache.entity_has_changed(room_id, position)
 
         return super(SlavedReceiptsStore, self).process_replication(result)
 
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
+        )