summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py86
1 files changed, 65 insertions, 21 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 6b9d848eaa..8c26f39fbb 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -31,9 +31,29 @@ class ReceiptsStore(SQLBaseStore):
         super(ReceiptsStore, self).__init__(hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
         )
 
+    @cachedInlineCallbacks()
+    def get_users_with_read_receipts_in_room(self, room_id):
+        receipts = yield self.get_receipts_for_room(room_id, "m.read")
+        defer.returnValue(set(r['user_id'] for r in receipts))
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns an ObservableDeferred
+        res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None)
+
+        if res and res.called and user_id in res.result:
+            # We'd only be adding to the set, so no point invalidating if the
+            # user is already there
+            return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -100,7 +120,7 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, max_entries=5000)
+    @cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
@@ -160,8 +180,8 @@ class ReceiptsStore(SQLBaseStore):
             "content": content,
         }])
 
-    @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
-                num_args=3, inlineCallbacks=True)
+    @cachedList(cached_method_name="get_linearized_receipts_for_room",
+                list_name="room_ids", num_args=3, inlineCallbacks=True)
     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             defer.returnValue({})
@@ -221,7 +241,7 @@ class ReceiptsStore(SQLBaseStore):
         defer.returnValue(results)
 
     def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_max_token()
+        return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
@@ -229,10 +249,14 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
         txn.call_after(
+            self._invalidate_get_users_with_receipts_in_room,
+            room_id, receipt_type, user_id,
+        )
+        txn.call_after(
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed,
@@ -244,6 +268,17 @@ class ReceiptsStore(SQLBaseStore):
             (user_id, room_id, receipt_type)
         )
 
+        res = self._simple_select_one_txn(
+            txn,
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        topological_ordering = int(res["topological_ordering"]) if res else None
+        stream_ordering = int(res["stream_ordering"]) if res else None
+
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
         sql = (
@@ -255,16 +290,7 @@ class ReceiptsStore(SQLBaseStore):
         txn.execute(sql, (room_id, receipt_type, user_id))
         results = txn.fetchall()
 
-        if results:
-            res = self._simple_select_one_txn(
-                txn,
-                table="events",
-                retcols=["topological_ordering", "stream_ordering"],
-                keyvalues={"event_id": event_id},
-            )
-            topological_ordering = int(res["topological_ordering"])
-            stream_ordering = int(res["stream_ordering"])
-
+        if results and topological_ordering:
             for to, so, _ in results:
                 if int(to) > topological_ordering:
                     return False
@@ -294,6 +320,14 @@ class ReceiptsStore(SQLBaseStore):
             }
         )
 
+        if receipt_type == "m.read" and topological_ordering:
+            self._remove_old_push_actions_before_txn(
+                txn,
+                room_id=room_id,
+                user_id=user_id,
+                topological_ordering=topological_ordering,
+            )
+
         return True
 
     @defer.inlineCallbacks
@@ -346,7 +380,7 @@ class ReceiptsStore(SQLBaseStore):
             room_id, receipt_type, user_id, event_ids, data
         )
 
-        max_persisted_id = self._stream_id_gen.get_max_token()
+        max_persisted_id = self._stream_id_gen.get_current_token()
 
         defer.returnValue((stream_id, max_persisted_id))
 
@@ -364,10 +398,14 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
         txn.call_after(
+            self._invalidate_get_users_with_receipts_in_room,
+            room_id, receipt_type, user_id,
+        )
+        txn.call_after(
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
+        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         self._simple_delete_txn(
             txn,
@@ -390,16 +428,22 @@ class ReceiptsStore(SQLBaseStore):
             }
         )
 
-    def get_all_updated_receipts(self, last_id, current_id, limit):
+    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+        if last_id == current_id:
+            return defer.succeed([])
+
         def get_all_updated_receipts_txn(txn):
             sql = (
                 "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
                 " FROM receipts_linearized"
                 " WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC"
-                " LIMIT ?"
             )
-            txn.execute(sql, (last_id, current_id, limit))
+            args = [last_id, current_id]
+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+            txn.execute(sql, args)
 
             return txn.fetchall()
         return self.runInteraction(