summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py223
1 files changed, 136 insertions, 87 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index f42b8014c7..0ac665e967 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,52 +14,52 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches.stream_change_cache import StreamChangeCache
+import abc
+import logging
+
+from canonicaljson import json
 
 from twisted.internet import defer
 
-import logging
-import ujson as json
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+from ._base import SQLBaseStore
+from .util.id_generators import StreamIdGenerator
 
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(ReceiptsStore, self).__init__(hs)
+class ReceiptsWorkerStore(SQLBaseStore):
+    """This is an abstract base class where subclasses must implement
+    `get_max_receipt_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
+    @abc.abstractmethod
+    def get_max_receipt_stream_id(self):
+        """Get the current max stream ID for receipts stream
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
         defer.returnValue(set(r['user_id'] for r in receipts))
 
-    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
-                                                    user_id):
-        if receipt_type != "m.read":
-            return
-
-        # Returns an ObservableDeferred
-        res = self.get_users_with_read_receipts_in_room.cache.get(
-            room_id, None, update_metrics=False,
-        )
-
-        if res:
-            if isinstance(res, defer.Deferred) and res.called:
-                res = res.result
-            if user_id in res:
-                # We'd only be adding to the set, so no point invalidating if the
-                # user is already there
-                return
-
-        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
-
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
@@ -139,7 +140,9 @@ class ReceiptsStore(SQLBaseStore):
         """
         room_ids = set(room_ids)
 
-        if from_key:
+        if from_key is not None:
+            # Only ask the database about rooms where there have been new
+            # receipts added since `from_key`
             room_ids = yield self._receipts_stream_cache.get_entities_changed(
                 room_ids, from_key
             )
@@ -150,7 +153,6 @@ class ReceiptsStore(SQLBaseStore):
 
         defer.returnValue([ev for res in results.values() for ev in res])
 
-    @cachedInlineCallbacks(num_args=3, tree=True)
     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """Get receipts for a single room for sending to clients.
 
@@ -161,7 +163,19 @@ class ReceiptsStore(SQLBaseStore):
                 from the start.
 
         Returns:
-            list: A list of receipts.
+            Deferred[list]: A list of receipts.
+        """
+        if from_key is not None:
+            # Check the cache first to see if any new receipts have been added
+            # since`from_key`. If not we can no-op.
+            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
+                defer.succeed([])
+
+        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+
+    @cachedInlineCallbacks(num_args=3, tree=True)
+    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+        """See get_linearized_receipts_for_room
         """
         def f(txn):
             if from_key:
@@ -210,7 +224,7 @@ class ReceiptsStore(SQLBaseStore):
             "content": content,
         }])
 
-    @cachedList(cached_method_name="get_linearized_receipts_for_room",
+    @cachedList(cached_method_name="_get_linearized_receipts_for_room",
                 list_name="room_ids", num_args=3, inlineCallbacks=True)
     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
@@ -270,11 +284,97 @@ class ReceiptsStore(SQLBaseStore):
         }
         defer.returnValue(results)
 
+    def get_all_updated_receipts(self, last_id, current_id, limit=None):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_receipts_txn(txn):
+            sql = (
+                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
+                " FROM receipts_linearized"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC"
+            )
+            args = [last_id, current_id]
+            if limit is not None:
+                sql += " LIMIT ?"
+                args.append(limit)
+            txn.execute(sql, args)
+
+            return txn.fetchall()
+        return self.runInteraction(
+            "get_all_updated_receipts", get_all_updated_receipts_txn
+        )
+
+    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
+                                                    user_id):
+        if receipt_type != "m.read":
+            return
+
+        # Returns either an ObservableDeferred or the raw result
+        res = self.get_users_with_read_receipts_in_room.cache.get(
+            room_id, None, update_metrics=False,
+        )
+
+        # first handle the Deferred case
+        if isinstance(res, defer.Deferred):
+            if res.called:
+                res = res.result
+            else:
+                res = None
+
+        if res and user_id in res:
+            # We'd only be adding to the set, so no point invalidating if the
+            # user is already there
+            return
+
+        self.get_users_with_read_receipts_in_room.invalidate((room_id,))
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    def __init__(self, db_conn, hs):
+        # We instantiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
+        res = self._simple_select_one_txn(
+            txn,
+            table="events",
+            retcols=["topological_ordering", "stream_ordering"],
+            keyvalues={"event_id": event_id},
+            allow_none=True
+        )
+
+        stream_ordering = int(res["stream_ordering"]) if res else None
+
+        # We don't want to clobber receipts for more recent events, so we
+        # have to compare orderings of existing receipts
+        if stream_ordering is not None:
+            sql = (
+                "SELECT stream_ordering, event_id FROM events"
+                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+            )
+            txn.execute(sql, (room_id, receipt_type, user_id))
+
+            for so, eid in txn:
+                if int(so) >= stream_ordering:
+                    logger.debug(
+                        "Ignoring new receipt for %s in favour of existing "
+                        "one for later event %s",
+                        event_id, eid,
+                    )
+                    return False
+
         txn.call_after(
             self.get_receipts_for_room.invalidate, (room_id, receipt_type)
         )
@@ -286,7 +386,7 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         txn.call_after(
             self._receipts_stream_cache.entity_has_changed,
@@ -298,34 +398,6 @@ class ReceiptsStore(SQLBaseStore):
             (user_id, room_id, receipt_type)
         )
 
-        res = self._simple_select_one_txn(
-            txn,
-            table="events",
-            retcols=["topological_ordering", "stream_ordering"],
-            keyvalues={"event_id": event_id},
-            allow_none=True
-        )
-
-        topological_ordering = int(res["topological_ordering"]) if res else None
-        stream_ordering = int(res["stream_ordering"]) if res else None
-
-        # We don't want to clobber receipts for more recent events, so we
-        # have to compare orderings of existing receipts
-        sql = (
-            "SELECT topological_ordering, stream_ordering, event_id FROM events"
-            " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
-            " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
-        )
-
-        txn.execute(sql, (room_id, receipt_type, user_id))
-
-        if topological_ordering:
-            for to, so, _ in txn:
-                if int(to) > topological_ordering:
-                    return False
-                elif int(to) == topological_ordering and int(so) >= stream_ordering:
-                    return False
-
         self._simple_delete_txn(
             txn,
             table="receipts_linearized",
@@ -349,12 +421,11 @@ class ReceiptsStore(SQLBaseStore):
             }
         )
 
-        if receipt_type == "m.read" and topological_ordering:
+        if receipt_type == "m.read" and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
                 txn,
                 room_id=room_id,
                 user_id=user_id,
-                topological_ordering=topological_ordering,
                 stream_ordering=stream_ordering,
             )
 
@@ -435,7 +506,7 @@ class ReceiptsStore(SQLBaseStore):
             self.get_receipts_for_user.invalidate, (user_id, receipt_type)
         )
         # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
+        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         self._simple_delete_txn(
             txn,
@@ -457,25 +528,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-    def get_all_updated_receipts(self, last_id, current_id, limit=None):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_receipts_txn(txn):
-            sql = (
-                "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
-                " FROM receipts_linearized"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-            )
-            args = [last_id, current_id]
-            if limit is not None:
-                sql += " LIMIT ?"
-                args.append(limit)
-            txn.execute(sql, args)
-
-            return txn.fetchall()
-        return self.runInteraction(
-            "get_all_updated_receipts", get_all_updated_receipts_txn
-        )