summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py18
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py17
-rw-r--r--synapse/replication/http/login.py2
-rw-r--r--synapse/replication/http/membership.py8
-rw-r--r--synapse/replication/http/presence.py4
-rw-r--r--synapse/replication/http/register.py4
-rw-r--r--synapse/replication/http/send_event.py7
-rw-r--r--synapse/replication/http/streams.py2
-rw-r--r--synapse/replication/slave/storage/_base.py6
-rw-r--r--synapse/replication/slave/storage/account_data.py8
-rw-r--r--synapse/replication/slave/storage/appservice.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py6
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py6
-rw-r--r--synapse/replication/slave/storage/devices.py8
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/replication/slave/storage/events.py24
-rw-r--r--synapse/replication/slave/storage/filtering.py6
-rw-r--r--synapse/replication/slave/storage/groups.py6
-rw-r--r--synapse/replication/slave/storage/keys.py2
-rw-r--r--synapse/replication/slave/storage/presence.py6
-rw-r--r--synapse/replication/slave/storage/profile.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py2
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py6
-rw-r--r--synapse/replication/slave/storage/registration.py2
-rw-r--r--synapse/replication/slave/storage/room.py6
-rw-r--r--synapse/replication/slave/storage/transactions.py2
28 files changed, 81 insertions, 91 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index fb0dd04f88..6a28c2db9d 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -20,8 +20,6 @@ import urllib
 from inspect import signature
 from typing import Dict, List, Tuple
 
-from twisted.internet import defer
-
 from synapse.api.errors import (
     CodeMessageException,
     HttpResponseException,
@@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
-    def _serialize_payload(**kwargs):
+    async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.
 
         Concrete implementations should have explicit parameters (rather than
@@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
         argument list.
 
         Returns:
-            Deferred[dict]|dict: If POST/PUT request then dictionary must be
-            JSON serialisable, otherwise must be appropriate for adding as
-            query args.
+            dict: If POST/PUT request then dictionary must be JSON serialisable,
+            otherwise must be appropriate for adding as query args.
         """
         return {}
 
@@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
         instance_map = hs.config.worker.instance_map
 
         @trace(opname="outgoing_replication_request")
-        @defer.inlineCallbacks
-        def send_request(instance_name="master", **kwargs):
+        async def send_request(instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
                     "Instance %r not in 'instance_map' config" % (instance_name,)
                 )
 
-            data = yield cls._serialize_payload(**kwargs)
+            data = await cls._serialize_payload(**kwargs)
 
             url_args = [
                 urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
                     headers = {}  # type: Dict[bytes, List[bytes]]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
-                        result = yield request_func(uri, data, headers=headers)
+                        result = await request_func(uri, data, headers=headers)
                         break
                     except CodeMessageException as e:
                         if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -207,7 +203,7 @@ class ReplicationEndpoint(object):
 
                     # If we timed out we probably don't need to worry about backing
                     # off too much, but lets just wait a little anyway.
-                    yield clock.sleep(1)
+                    await clock.sleep(1)
             except HttpResponseException as e:
                 # We convert to SynapseError as we know that it was a SynapseError
                 # on the master process that we should send to the client. (And
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index e32aac0a25..20f3ba76c0 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()
 
     @staticmethod
-    def _serialize_payload(user_id):
+    async def _serialize_payload(user_id):
         return {}
 
     async def _handle_request(self, request, user_id):
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index ca065e819e..6b56315148 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
@@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
 
     @staticmethod
-    @defer.inlineCallbacks
-    def _serialize_payload(store, event_and_contexts, backfilled):
+    async def _serialize_payload(store, event_and_contexts, backfilled):
         """
         Args:
             store
@@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         """
         event_payloads = []
         for event, context in event_and_contexts:
-            serialized_context = yield defer.ensureDeferred(
-                context.serialize(event, store)
-            )
+            serialized_context = await context.serialize(event, store)
 
             event_payloads.append(
                 {
@@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()
 
     @staticmethod
-    def _serialize_payload(edu_type, origin, content):
+    async def _serialize_payload(edu_type, origin, content):
         return {"origin": origin, "content": content}
 
     async def _handle_request(self, request, edu_type):
@@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         self.registry = hs.get_federation_registry()
 
     @staticmethod
-    def _serialize_payload(query_type, args):
+    async def _serialize_payload(query_type, args):
         """
         Args:
             query_type (str)
@@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()
 
     @staticmethod
-    def _serialize_payload(room_id, args):
+    async def _serialize_payload(room_id, args):
         """
         Args:
             room_id (str)
@@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
         self.store = hs.get_datastore()
 
     @staticmethod
-    def _serialize_payload(room_id, room_version):
+    async def _serialize_payload(room_id, room_version):
         return {"room_version": room_version.identifier}
 
     async def _handle_request(self, request, room_id):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 798b9d3af5..fb326bb869 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+    async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
         """
         Args:
             device_id (str|None): Device ID to use, if None a new one is
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 63ef6eb7be..741329ab5f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()
 
     @staticmethod
-    def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+    async def _serialize_payload(
+        requester, room_id, user_id, remote_room_hosts, content
+    ):
         """
         Args:
             requester(Requester)
@@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         self.member_handler = hs.get_room_member_handler()
 
     @staticmethod
-    def _serialize_payload(  # type: ignore
+    async def _serialize_payload(  # type: ignore
         invite_event_id: str,
         txn_id: Optional[str],
         requester: Requester,
@@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         self.distributor = hs.get_distributor()
 
     @staticmethod
-    def _serialize_payload(room_id, user_id, change):
+    async def _serialize_payload(room_id, user_id, change):
         """
         Args:
             room_id (str)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index ea1b33331b..bc9aa82cb4 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()
 
     @staticmethod
-    def _serialize_payload(user_id):
+    async def _serialize_payload(user_id):
         return {}
 
     async def _handle_request(self, request, user_id):
@@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
         self._presence_handler = hs.get_presence_handler()
 
     @staticmethod
-    def _serialize_payload(user_id, state, ignore_status_msg=False):
+    async def _serialize_payload(user_id, state, ignore_status_msg=False):
         return {
             "state": state,
             "ignore_status_msg": ignore_status_msg,
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 0c4aca1291..ce9420aa69 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    def _serialize_payload(
+    async def _serialize_payload(
         user_id,
         password_hash,
         was_guest,
@@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    def _serialize_payload(user_id, auth_result, access_token):
+    async def _serialize_payload(user_id, auth_result, access_token):
         """
         Args:
             user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index b30e4d5039..f13d452426 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
@@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
         self.clock = hs.get_clock()
 
     @staticmethod
-    @defer.inlineCallbacks
-    def _serialize_payload(
+    async def _serialize_payload(
         event_id, store, event, context, requester, ratelimit, extra_users
     ):
         """
@@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             extra_users (list(UserID)): Any extra users to notify about event
         """
 
-        serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
+        serialized_context = await context.serialize(event, store)
 
         payload = {
             "event": event.get_pdu_json(),
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index bde97eef32..309159e304 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         self.streams = hs.get_replication_streams()
 
     @staticmethod
-    def _serialize_payload(stream_name, from_token, upto_token):
+    async def _serialize_payload(stream_name, from_token, upto_token):
         return {"from_token": from_token, "upto_token": upto_token}
 
     async def _handle_request(self, request, stream_name):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@
 import logging
 from typing import Optional
 
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = MultiWriterIdGenerator(
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 525b94fd87..154f0e687c 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -17,13 +17,13 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn,
             "account_data",
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
     ApplicationServiceTransactionWorkerStore,
     ApplicationServiceWorkerStore,
 )
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..60dd3f6701 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.descriptors import Cache
 
 from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index a8a16dbc71..ee7f69a918 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,14 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_inbox", "stream_id"
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..722f3745e9 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@
 # limitations under the License.
 import logging
 
-from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.relations import RelationsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.data_stores.main.stream import StreamWorkerStore
-from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
@@ -55,11 +55,11 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
             db_conn,
             "current_state_delta_stream",
             entity_column="room_id",
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.filtering import FilteringStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 5d210fa3a1..3291558c7a 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -16,13 +16,13 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.keys import KeyStore
+from synapse.storage.databases.main.keys import KeyStore
 
 # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
 # the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 2938cb8e43..a912c04360 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,8 +15,8 @@
 
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage import DataStore
-from synapse.storage.data_stores.main.presence import PresenceStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.presence import PresenceStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.data_stores.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
 
 
 class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 23ec1c5b11..590187df46 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 from synapse.replication.tcp.streams import PushRulesStream
-from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
 from .events import SlavedEventStore
 
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index ff449f3658..63300e5da6 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,15 +15,15 @@
 # limitations under the License.
 
 from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.pusher import PusherWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6982686eb5..17ba1f22ac 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,15 +15,15 @@
 # limitations under the License.
 
 from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8710207ada..427c81772b 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,15 +14,15 @@
 # limitations under the License.
 
 from synapse.replication.tcp.streams import PublicRoomsStream
-from synapse.storage.data_stores.main.room import RoomWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.room import RoomWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, database: Database, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.data_stores.main.transactions import TransactionStore
+from synapse.storage.databases.main.transactions import TransactionStore
 
 from ._base import BaseSlavedStore