summary refs log tree commit diff
path: root/synapse/replication/slave/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/receipts.py')
-rw-r--r--synapse/replication/slave/storage/receipts.py19
1 files changed, 6 insertions, 13 deletions
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py

index be716cc558..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py
@@ -14,23 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import ReceiptsStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( @@ -52,8 +45,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "receipts": - self._receipts_id_gen.advance(token) + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id