diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index f0e29e9836..1647072f65 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -53,6 +53,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
+ self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+ self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts":
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 40530632c6..eac8694e0f 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -292,20 +292,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_all_updated_receipts", get_all_updated_receipts_txn
)
-
-class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, db_conn, hs):
- # We instantiate this first as the ReceiptsWorkerStore constructor
- # needs to be able to call get_max_receipt_stream_id
- self._receipts_id_gen = StreamIdGenerator(
- db_conn, "receipts_linearized", "stream_id"
- )
-
- super(ReceiptsStore, self).__init__(db_conn, hs)
-
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
-
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
if receipt_type != "m.read":
@@ -326,6 +312,20 @@ class ReceiptsStore(ReceiptsWorkerStore):
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+class ReceiptsStore(ReceiptsWorkerStore):
+ def __init__(self, db_conn, hs):
+ # We instantiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
+ db_conn, "receipts_linearized", "stream_id"
+ )
+
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_current_token()
+
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
txn.call_after(
|