summary refs log tree commit diff
path: root/synapse/replication/slave/storage/account_data.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage/account_data.py')
-rw-r--r--synapse/replication/slave/storage/account_data.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py

index 9db6c62bc7..154f0e687c 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py
@@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.storage.data_stores.main.tags import TagsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data", @@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved return self._account_data_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "tag_account_data": + if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == "account_data": + elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(token) for row in rows: if not row.room_id: