summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/receipts.py79
1 files changed, 67 insertions, 12 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index cac1a5657e..fb35472ad7 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches import cache_counter, caches_by_name
 
 from twisted.internet import defer
 
@@ -54,19 +55,13 @@ class ReceiptsStore(SQLBaseStore):
                 self, room_ids, from_key
             )
 
-        results = yield defer.gatherResults(
-            [
-                self.get_linearized_receipts_for_room(
-                    room_id, to_key, from_key=from_key
-                )
-                for room_id in room_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        results = yield self._get_linearized_receipts_for_rooms(
+            room_ids, to_key, from_key=from_key
+        )
 
-        defer.returnValue([ev for res in results for ev in res])
+        defer.returnValue([ev for res in results.values() for ev in res])
 
-    @defer.inlineCallbacks
+    @cachedInlineCallbacks(num_args=3, max_entries=5000)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
@@ -126,6 +121,61 @@ class ReceiptsStore(SQLBaseStore):
             "content": content,
         }])
 
+    @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
+                num_args=3, inlineCallbacks=True)
+    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+        if not room_ids:
+            defer.returnValue({})
+
+        def f(txn):
+            if from_key:
+                sql = (
+                    "SELECT * FROM receipts_linearized WHERE"
+                    " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
+                ) % (
+                    ",".join(["?"] * len(room_ids))
+                )
+                args = list(room_ids)
+                args.extend([from_key, to_key])
+
+                txn.execute(sql, args)
+            else:
+                sql = (
+                    "SELECT * FROM receipts_linearized WHERE"
+                    " room_id IN (%s) AND stream_id <= ?"
+                ) % (
+                    ",".join(["?"] * len(room_ids))
+                )
+
+                args = list(room_ids)
+                args.append(to_key)
+
+                txn.execute(sql, args)
+
+            return self.cursor_to_dict(txn)
+
+        txn_results = yield self.runInteraction(
+            "_get_linearized_receipts_for_rooms", f
+        )
+
+        results = {}
+        for row in txn_results:
+            results.setdefault(row["room_id"], {
+                "type": "m.receipt",
+                "room_id": row["room_id"],
+                "content": {},
+            })["content"].setdefault(
+                row["event_id"], {}
+            ).setdefault(
+                row["receipt_type"], {}
+            )[row["user_id"]] = json.loads(row["data"])
+
+        results = {
+            room_id: [results[room_id]] if room_id in results else []
+            for room_id in room_ids
+        }
+        defer.returnValue(results)
+
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
 
@@ -305,6 +355,8 @@ class _RoomStreamChangeCache(object):
         self._room_to_key = {}
         self._cache = sorteddict()
         self._earliest_key = None
+        self.name = "ReceiptsRoomChangeCache"
+        caches_by_name[self.name] = self._cache
 
     @defer.inlineCallbacks
     def get_rooms_changed(self, store, room_ids, key):
@@ -318,8 +370,11 @@ class _RoomStreamChangeCache(object):
             result = set(
                 self._cache[k] for k in keys[i:]
             ).intersection(room_ids)
+
+            cache_counter.inc_hits(self.name)
         else:
             result = room_ids
+            cache_counter.inc_misses(self.name)
 
         defer.returnValue(result)