diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 456bc005a0..b91a528245 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,7 +18,9 @@ from typing import Dict
import six
-from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
@@ -34,8 +36,8 @@ def __func__(inp):
class BaseSlavedStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(BaseSlavedStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
@@ -62,7 +64,7 @@ class BaseSlavedStore(SQLBaseStore):
if stream_name == "caches":
self._cache_id_gen.advance(token)
for row in rows:
- if row.cache_func == _CURRENT_STATE_CACHE_NAME:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
|