summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index aa62474a46..b11cf7ff62 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from twisted.internet import defer
 
+import abc
 import logging
 import ujson as json
 
@@ -29,21 +30,30 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, receipts_id_gen, db_conn, hs):
-        """
-        Args:
-            receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
-            db_conn: Database connection
-            hs (Homeserver)
-        """
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    """This is an abstract base class where subclasses must implement
+    `get_max_receipt_stream_id` which can be called in the initializer.
+    """
 
-        self._receipts_id_gen = receipts_id_gen
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
+    @abc.abstractmethod
+    def get_max_receipt_stream_id(self):
+        """Get the current max stream ID for receipts stream
+
+        Returns:
+            int
+        """
+        pass
+
     @cachedInlineCallbacks()
     def get_users_with_read_receipts_in_room(self, room_id):
         receipts = yield self.get_receipts_for_room(room_id, "m.read")
@@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
         defer.returnValue(results)
 
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
     def get_all_updated_receipts(self, last_id, current_id, limit=None):
         if last_id == current_id:
             return defer.succeed([])
@@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 class ReceiptsStore(ReceiptsWorkerStore):
     def __init__(self, db_conn, hs):
-        receipts_id_gen = StreamIdGenerator(
+        # We instansiate this first as the ReceiptsWorkerStore constructor
+        # needs to be able to call get_max_receipt_stream_id
+        self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
+        super(ReceiptsStore, self).__init__(db_conn, hs)
+
+    def get_max_receipt_stream_id(self):
+        return self._receipts_id_gen.get_current_token()
 
     def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
                                                     user_id):