summary refs log tree commit diff
path: root/synapse/replication/slave/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/receipts.py')
-rw-r--r--synapse/replication/slave/storage/receipts.py40
1 files changed, 12 insertions, 28 deletions
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index b371574ece..ed12342f40 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,13 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.receipts import ReceiptsWorkerStore
+
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-from synapse.storage import DataStore
-from synapse.storage.receipts import ReceiptsStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
 # So, um, we want to borrow a load of functions intended for reading from
 # a DataStore, but we don't want to take functions that either write to the
 # DataStore or are cached and don't have cache invalidation logic.
@@ -29,36 +28,19 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 # the method descriptor on the DataStore and chuck them into our class.
 
 
-class SlavedReceiptsStore(BaseSlavedStore):
+class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
 
     def __init__(self, db_conn, hs):
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
-
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
-        )
-
-    get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
-    get_linearized_receipts_for_room = (
-        ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
-    )
-    _get_linearized_receipts_for_rooms = (
-        ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
-    )
-    get_last_receipt_event_id_for_user = (
-        ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
-    )
-
-    get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
-    get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
+        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
 
-    get_linearized_receipts_for_rooms = (
-        DataStore.get_linearized_receipts_for_rooms.__func__
-    )
+    def get_max_receipt_stream_id(self):
+        return self._receipts_id_gen.get_current_token()
 
     def stream_positions(self):
         result = super(SlavedReceiptsStore, self).stream_positions()
@@ -67,10 +49,12 @@ class SlavedReceiptsStore(BaseSlavedStore):
 
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self.get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "receipts":