diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index aa62474a46..b11cf7ff62 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from twisted.internet import defer
+import abc
import logging
import ujson as json
@@ -29,21 +30,30 @@ logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
- def __init__(self, receipts_id_gen, db_conn, hs):
- """
- Args:
- receipts_id_gen (StreamIdGenerator|SlavedIdTracker)
- db_conn: Database connection
- hs (Homeserver)
- """
- super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+ """This is an abstract base class where subclasses must implement
+ `get_max_receipt_stream_id` which can be called in the initializer.
+ """
- self._receipts_id_gen = receipts_id_gen
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
)
+ @abc.abstractmethod
+ def get_max_receipt_stream_id(self):
+ """Get the current max stream ID for receipts stream
+
+ Returns:
+ int
+ """
+ pass
+
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
@@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
defer.returnValue(results)
- def get_max_receipt_stream_id(self):
- return self._receipts_id_gen.get_current_token()
-
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
return defer.succeed([])
@@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
def __init__(self, db_conn, hs):
- receipts_id_gen = StreamIdGenerator(
+ # We instansiate this first as the ReceiptsWorkerStore constructor
+ # needs to be able to call get_max_receipt_stream_id
+ self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs)
+ super(ReceiptsStore, self).__init__(db_conn, hs)
+
+ def get_max_receipt_stream_id(self):
+ return self._receipts_id_gen.get_current_token()
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
|