diff options
author | Erik Johnston <erik@matrix.org> | 2018-02-20 17:33:18 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2018-02-20 17:43:57 +0000 |
commit | e316bbb4c07cad97c4cff5bc0c5b0dc2cd7bc519 (patch) | |
tree | a2d1b661e9aa15b9d8c07100bdcd77dc5fe7bd9c | |
parent | Split ReceiptsStore (diff) | |
download | synapse-e316bbb4c07cad97c4cff5bc0c5b0dc2cd7bc519.tar.xz |
Use abstract base class to access stream IDs
Diffstat (limited to '')
-rw-r--r-- | synapse/replication/slave/storage/receipts.py | 9 | ||||
-rw-r--r-- | synapse/storage/receipts.py | 42 |
2 files changed, 34 insertions, 17 deletions
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 4e845ec041..a2eb4a02db 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -31,11 +31,16 @@ from synapse.storage.receipts import ReceiptsWorkerStore class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): - receipts_id_gen = SlavedIdTracker( + # We instansiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) + super(SlavedReceiptsStore, self).__init__(db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index aa62474a46..b11cf7ff62 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -21,6 +21,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer +import abc import logging import ujson as json @@ -29,21 +30,30 @@ logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, receipts_id_gen, db_conn, hs): - """ - Args: - receipts_id_gen (StreamIdGenerator|SlavedIdTracker) - db_conn: Database connection - hs (Homeserver) - """ - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ - self._receipts_id_gen = receipts_id_gen + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + pass + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") @@ -260,9 +270,6 @@ class ReceiptsWorkerStore(SQLBaseStore): } defer.returnValue(results) - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - def get_all_updated_receipts(self, last_id, current_id, limit=None): if last_id == current_id: return defer.succeed([]) @@ -288,11 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): def __init__(self, db_conn, hs): - receipts_id_gen = StreamIdGenerator( + # We instansiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(receipts_id_gen, db_conn, hs) + super(ReceiptsStore, self).__init__(db_conn, hs) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_current_token() def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, user_id): |