summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/receipts.py92
1 files changed, 89 insertions, 3 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 07f8edaace..503f68f858 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -17,6 +17,9 @@ from ._base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
+from synapse.util import unwrapFirstError
+
+from blist import sorteddict
 import logging
 
 
@@ -24,6 +27,29 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(ReceiptsStore, self).__init__(hs)
+
+        self._receipts_stream_cache = _RoomStreamChangeCache()
+
+    @defer.inlineCallbacks
+    def get_linearized_receipts_for_rooms(self, room_ids, from_key, to_key):
+        room_ids = set(room_ids)
+
+        if from_key:
+            room_ids = yield self._receipts_stream_cache.get_rooms_changed(
+                self, room_ids, from_key
+            )
+
+        results = yield defer.gatherResults(
+            [
+                self.get_linearized_receipts_for_room(room_id, from_key, to_key)
+                for room_id in room_ids
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        defer.returnValue([ev for res in results for ev in res])
 
     @defer.inlineCallbacks
     def get_linearized_receipts_for_room(self, room_id, from_key, to_key):
@@ -57,15 +83,22 @@ class ReceiptsStore(SQLBaseStore):
             "get_linearized_receipts_for_room", f
         )
 
-        result = {}
+        if not rows:
+            defer.returnValue([])
+
+        content = {}
         for row in rows:
-            result.setdefault(
+            content.setdefault(
                 row["event_id"], {}
             ).setdefault(
                 row["receipt_type"], []
             ).append(row["user_id"])
 
-        defer.returnValue(result)
+        defer.returnValue([{
+            "type": "m.receipt",
+            "room_id": room_id,
+            "content": content,
+        }])
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_max_token(self)
@@ -174,6 +207,9 @@ class ReceiptsStore(SQLBaseStore):
 
         stream_id_manager = yield self._receipts_id_gen.get_next(self)
         with stream_id_manager as stream_id:
+            yield self._receipts_stream_cache.room_has_changed(
+                self, room_id, stream_id
+            )
             have_persisted = yield self.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
@@ -223,3 +259,53 @@ class ReceiptsStore(SQLBaseStore):
                 for event_id in event_ids
             ],
         )
+
+
+class _RoomStreamChangeCache(object):
+    """Keeps track of the stream_id of the latest change in rooms.
+
+    Given a list of rooms and stream key, it will give a subset of rooms that
+    may have changed since that key. If the key is too old then the cache
+    will simply return all rooms.
+    """
+    def __init__(self, size_of_cache=1000):
+        self._size_of_cache = size_of_cache
+        self._room_to_key = {}
+        self._cache = sorteddict()
+        self._earliest_key = None
+
+    @defer.inlineCallbacks
+    def get_rooms_changed(self, store, room_ids, key):
+        if key > (yield self._get_earliest_key(store)):
+            keys = self._cache.keys()
+            i = keys.bisect_right(key)
+
+            result = set(
+                self._cache[k] for k in keys[i:]
+            ).intersection(room_ids)
+        else:
+            result = room_ids
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def room_has_changed(self, store, room_id, key):
+        if key > (yield self._get_earliest_key(store)):
+            old_key = self._room_to_key.get(room_id, None)
+            if old_key:
+                key = max(key, old_key)
+                self._cache.pop(old_key, None)
+            self._cache[key] = room_id
+
+            while len(self._cache) > self._size_of_cache:
+                k, r = self._cache.popitem()
+                self._earliest_key = max(k, self._earliest_key)
+                self._room_to_key.pop(r, None)
+
+    @defer.inlineCallbacks
+    def _get_earliest_key(self, store):
+        if self._earliest_key is None:
+            self._earliest_key = yield store.get_max_receipt_stream_id()
+            self._earliest_key = int(self._earliest_key)
+
+        defer.returnValue(self._earliest_key)