summary refs log tree commit diff
path: root/synapse/replication/slave
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave')
-rw-r--r--synapse/replication/slave/storage/_base.py6
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py14
-rw-r--r--synapse/replication/slave/storage/account_data.py17
-rw-r--r--synapse/replication/slave/storage/appservice.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py8
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py13
-rw-r--r--synapse/replication/slave/storage/devices.py15
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/replication/slave/storage/events.py24
-rw-r--r--synapse/replication/slave/storage/filtering.py6
-rw-r--r--synapse/replication/slave/storage/groups.py11
-rw-r--r--synapse/replication/slave/storage/keys.py2
-rw-r--r--synapse/replication/slave/storage/presence.py11
-rw-r--r--synapse/replication/slave/storage/profile.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py17
-rw-r--r--synapse/replication/slave/storage/pushers.py11
-rw-r--r--synapse/replication/slave/storage/receipts.py19
-rw-r--r--synapse/replication/slave/storage/registration.py2
-rw-r--r--synapse/replication/slave/storage/room.py11
-rw-r--r--synapse/replication/slave/storage/transactions.py2
20 files changed, 102 insertions, 93 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@
 import logging
 from typing import Optional
 
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@
 from synapse.storage.util.id_generators import _load_current_id
 
 
-class SlavedIdTracker(object):
+class SlavedIdTracker:
     def __init__(self, db_conn, table, column, extra_tables=[], step=1):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
-            self.advance(_load_current_id(db_conn, table, column))
+            self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, new_id):
+    def advance(self, instance_name, new_id):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self):
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
             int
         """
         return self._current
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to
+        `get_current_token`.
+        """
+        return self.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..bb66ba9b80 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,13 +16,14 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn,
             "account_data",
@@ -39,13 +40,13 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
         return self._account_data_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "tag_account_data":
-            self._account_data_id_gen.advance(token)
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == "account_data":
-            self._account_data_id_gen.advance(token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceTransactionWorkerStore,
     ApplicationServiceWorkerStore,
 )
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..a6fdedde63 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,22 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.descriptors import Cache
 
 from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
-    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
 
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,17 +15,18 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_max_stream_id", "stream_id"
+            db_conn, "device_inbox", "stream_id"
         )
         self._device_inbox_stream_cache = StreamChangeCache(
             "DeviceInboxStreamChangeCache",
@@ -44,8 +45,8 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
         )
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "to_device":
-            self._device_inbox_id_gen.advance(token)
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
             for row in rows:
                 if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
@@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
+    def get_device_stream_token(self) -> int:
+        return self._device_list_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
-            self._device_list_id_gen.advance(token)
+            self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(token)
+            self._device_list_id_gen.advance(instance_name, token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@
 # limitations under the License.
 import logging
 
-from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.relations import RelationsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.data_stores.main.stream import StreamWorkerStore
-from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
@@ -55,11 +55,11 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
             db_conn,
             "current_state_delta_stream",
             entity_column="room_id",
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.filtering import FilteringStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,13 +15,14 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import GroupServerStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
@@ -38,8 +39,8 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
         return self._group_updates_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "groups":
-            self._group_updates_id_gen.advance(token)
+        if stream_name == GroupServerStream.NAME:
+            self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.keys import KeyStore
+from synapse.storage.databases.main.keys import KeyStore
 
 # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
 # the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..025f6f6be8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,9 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage import DataStore
-from synapse.storage.data_stores.main.presence import PresenceStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.presence import PresenceStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
@@ -23,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
@@ -42,8 +43,8 @@ class SlavedPresenceStore(BaseSlavedStore):
         return self._presence_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "presence":
-            self._presence_id_gen.advance(token)
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.data_stores.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
 
 
 class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..de904c943c 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,24 +14,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import PushRulesStream
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_push_rules_stream_token(self):
-        return (
-            self._push_rules_stream_id_gen.get_current_token(),
-            self._stream_id_gen.get_current_token(),
-        )
-
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "push_rules":
-            self._push_rules_stream_id_gen.advance(token)
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..9da218bfe8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,15 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PushersStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.pusher import PusherWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "pushers":
-            self._pushers_id_gen.advance(token)
+        if stream_name == PushersStream.NAME:
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..5c2986e050 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,23 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ReceiptsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
@@ -52,8 +45,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "receipts":
-            self._receipts_id_gen.advance(token)
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
                 self.invalidate_caches_for_receipt(
                     row.room_id, row.receipt_type, row.user_id
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..80ae803ad9 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.room import RoomWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PublicRoomsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.room import RoomWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
         return self._public_room_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == "public_rooms":
-            self._public_room_id_gen.advance(token)
+        if stream_name == PublicRoomsStream.NAME:
+            self._public_room_id_gen.advance(instance_name, token)
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.transactions import TransactionStore
+from synapse.storage.databases.main.transactions import TransactionStore
 
 from ._base import BaseSlavedStore