summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-12-06 13:08:40 +0000
committerErik Johnston <erik@matrix.org>2019-12-06 13:30:06 +0000
commit9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc (patch)
tree4bb85c953cb523732bcc53ec01b7c1f8209ac1ad
parentMerge pull request #6469 from matrix-org/erikj/make_database_class (diff)
downloadsynapse-9a4fb457cf5918c85068ea249cd2d58b3e2e3cfc.tar.xz
Change DataStores to accept 'database' param.
-rw-r--r--synapse/app/federation_sender.py5
-rw-r--r--synapse/app/user_dir.py5
-rw-r--r--synapse/replication/slave/storage/_base.py5
-rw-r--r--synapse/replication/slave/storage/account_data.py5
-rw-r--r--synapse/replication/slave/storage/client_ips.py5
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py5
-rw-r--r--synapse/replication/slave/storage/devices.py5
-rw-r--r--synapse/replication/slave/storage/events.py5
-rw-r--r--synapse/replication/slave/storage/filtering.py5
-rw-r--r--synapse/replication/slave/storage/groups.py5
-rw-r--r--synapse/replication/slave/storage/presence.py5
-rw-r--r--synapse/replication/slave/storage/push_rule.py5
-rw-r--r--synapse/replication/slave/storage/pushers.py5
-rw-r--r--synapse/replication/slave/storage/receipts.py5
-rw-r--r--synapse/replication/slave/storage/room.py5
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/data_stores/main/__init__.py5
-rw-r--r--synapse/storage/data_stores/main/account_data.py9
-rw-r--r--synapse/storage/data_stores/main/appservice.py5
-rw-r--r--synapse/storage/data_stores/main/client_ips.py9
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py9
-rw-r--r--synapse/storage/data_stores/main/devices.py9
-rw-r--r--synapse/storage/data_stores/main/event_federation.py5
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py9
-rw-r--r--synapse/storage/data_stores/main/events.py5
-rw-r--r--synapse/storage/data_stores/main/events_bg_updates.py5
-rw-r--r--synapse/storage/data_stores/main/events_worker.py5
-rw-r--r--synapse/storage/data_stores/main/media_repository.py11
-rw-r--r--synapse/storage/data_stores/main/monthly_active_users.py7
-rw-r--r--synapse/storage/data_stores/main/push_rule.py5
-rw-r--r--synapse/storage/data_stores/main/receipts.py9
-rw-r--r--synapse/storage/data_stores/main/registration.py13
-rw-r--r--synapse/storage/data_stores/main/room.py9
-rw-r--r--synapse/storage/data_stores/main/roommember.py13
-rw-r--r--synapse/storage/data_stores/main/search.py9
-rw-r--r--synapse/storage/data_stores/main/state.py13
-rw-r--r--synapse/storage/data_stores/main/stats.py5
-rw-r--r--synapse/storage/data_stores/main/stream.py5
-rw-r--r--synapse/storage/data_stores/main/transactions.py5
-rw-r--r--synapse/storage/data_stores/main/user_directory.py9
-rw-r--r--tests/storage/test_appservice.py5
41 files changed, 156 insertions, 114 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 448e45e00f..f24920a7d6 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.replication.tcp.streams._base import ReceiptsStream
 from synapse.server import HomeServer
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.types import ReadReceipt
 from synapse.util.async_helpers import Linearizer
@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
     SlavedDeviceStore,
     SlavedPresenceStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
 
         # We pull out the current federation stream position now so that we
         # always have a known value for the federation position in memory so
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index b6d4481725..c01fb34a9b 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.server import HomeServer
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
@@ -60,8 +61,8 @@ class UserDirectorySlaveStore(
     UserDirectoryStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 6ece1d6745..b91a528245 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -20,6 +20,7 @@ import six
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
 from ._slaved_id_tracker import SlavedIdTracker
@@ -35,8 +36,8 @@ def __func__(inp):
 
 
 class BaseSlavedStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen = SlavedIdTracker(
                 db_conn, "cache_invalidation_stream", "stream_id"
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index bc2f6a12ae..ebe94909cb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index b4f58cea19..fbf996e33a 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 9fb6c5c6ff..0c237c6e0f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,13 +16,14 @@
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_max_stream_id", "stream_id"
         )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index de50748c30..dc625e0d7a 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.data_stores.main.devices import DeviceWorkerStore
 from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d0a0eaf75b..29f35b9915 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.storage.data_stores.main.stream import StreamWorkerStore
 from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -59,13 +60,13 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
         self._backfill_id_gen = SlavedIdTracker(
             db_conn, "events", "stream_ordering", step=-1
         )
 
-        super(SlavedEventStore, self).__init__(db_conn, hs)
+        super(SlavedEventStore, self).__init__(database, db_conn, hs)
 
     # Cached functions can't be accessed through a class instance so we need
     # to reach inside the __dict__ to extract them.
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 5c84ebd125..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 28a46edd28..69a4ae42f9 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.storage import DataStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedGroupServerStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 747ced0c84..f552e7c972 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,6 +15,7 @@
 
 from synapse.storage import DataStore
 from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
@@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
         self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 3655f05e54..eebd5a1fb6 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,17 +15,18 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.database import Database
 
 from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._push_rules_stream_id_gen = SlavedIdTracker(
             db_conn, "push_rules_stream", "stream_id"
         )
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+        super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
 
     def get_push_rules_stream_token(self):
         return (
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index b4331d0799..f22c2d44a3 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,14 +15,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 43d823c601..d40dc6e1f5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index d9ad386b28..3a20f45316 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,14 +14,15 @@
 # limitations under the License.
 
 from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7e27d4e97..f9e7f9a71e 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,7 +37,7 @@ class SQLBaseStore(object):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = hs.database_engine
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 6adb8adb04..7f5fd81bcf 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -20,6 +20,7 @@ import logging
 import time
 
 from synapse.api.constants import PresenceState
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
@@ -111,7 +112,7 @@ class DataStore(
     RelationsStore,
     CacheInvalidationStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = hs.database_engine
@@ -169,7 +170,7 @@ class DataStore(
         else:
             self._cache_id_gen = None
 
-        super(DataStore, self).__init__(db_conn, hs)
+        super(DataStore, self).__init__(database, db_conn, hs)
 
         self._presence_on_startup = self._get_active_presence(db_conn)
 
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index a96fe9485c..44d20c19bf 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
 
 
 class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(AccountDataStore, self).__init__(db_conn, hs)
+        super(AccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 6b2e12719c..b2f39649fd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 7b470a58f1..320c5b0f07 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
@@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "user_ips_device_index",
@@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
         )
 
-        super(ClientIpStore, self).__init__(db_conn, hs)
+        super(ClientIpStore, self).__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.user_ips_max_age
 
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 3c9f09301a..85cfa16850 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 logger = logging.getLogger(__name__)
@@ -210,8 +211,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "device_inbox_stream_index",
@@ -241,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 91ddaf137e..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -31,6 +31,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import (
@@ -642,8 +643,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             "device_lists_stream_idx",
@@ -692,8 +693,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 31d2e8eb28..1f517e8fad 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -491,8 +492,8 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventFederationStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index eec054cd48..9988a6d3fc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -611,8 +612,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d644c82784..da1529f6ea 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -41,6 +41,7 @@ from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -95,8 +96,8 @@ def _retry_on_integrity_error(func):
 class EventsStore(
     StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(EventsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsStore, self).__init__(database, db_conn, hs)
 
         # Collect metrics on the number of forward extremities that exist.
         # Counter of number of extremities to count
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index cb1fc30c31..efee17b929 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventContentFields
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -33,8 +34,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
-    def __init__(self, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
 
         self.db.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e041fc5eac..9ee117ce0f 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -33,6 +33,7 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
@@ -55,8 +56,8 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
 class EventsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._get_event_cache = Cache(
             "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 03c9c6f8ae..80ca36dedf 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+            database, db_conn, hs
+        )
 
         self.db.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
@@ -31,8 +34,8 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 34bf3a1880..27158534cb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,6 +17,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersStore(SQLBaseStore):
-    def __init__(self, dbconn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(None, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
         # Do not add more reserved users than the total allowable number
         self.db.new_transaction(
-            dbconn,
+            db_conn,
             "initialise_mau_threepids",
             [],
             [],
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index de682cc63a..5ba13aa973 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -72,8 +73,8 @@ class PushRulesWorkerStore(
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index ac2d45bd5c..96e54d145e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -315,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 
 class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 1ef143c6d8..5e8ecac0ea 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -27,6 +27,7 @@ from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -36,8 +37,8 @@ logger = logging.getLogger(__name__)
 
 
 class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
@@ -794,8 +795,8 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -920,8 +921,8 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
 
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index da42dae243..0148be20d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -29,6 +29,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -361,8 +362,8 @@ class RoomWorkerStore(SQLBaseStore):
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -440,8 +441,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
 
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 929f6b0d39..92e3b9c512 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -32,6 +32,7 @@ from synapse.storage._base import (
     make_in_list_sql_clause,
 )
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
     GetRoomsForUserWithStreamOrdering,
@@ -54,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
@@ -835,8 +836,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
         self.db.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
@@ -991,8 +992,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index ffa1817e64..4eec2fae5e 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 logger = logging.getLogger(__name__)
@@ -42,8 +43,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
@@ -342,8 +343,8 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(SearchStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchStore, self).__init__(database, db_conn, hs)
 
     def store_event_search_txn(self, txn, event, key, value):
         """Add event to the search table
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 7d5a9f8128..9ef7b48c74 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -28,6 +28,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.util.caches import get_cache_factor_for, intern_string
@@ -213,8 +214,8 @@ class StateGroupWorkerStore(
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -1029,8 +1030,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
 
-    def __init__(self, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
         self.db.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
@@ -1245,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateStore, self).__init__(database, db_conn, hs)
 
     def _store_event_state_mappings_txn(
         self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 40579bf965..7bc186e9a1 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, db_conn, hs):
-        super(StatsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StatsStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 2ff8c57109..140da8dad6 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -47,6 +47,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -251,8 +252,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(StreamWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         events_max = self.get_room_max_stream_ordering()
         event_cache_prefill, min_event_val = self.db.get_cache_dict(
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index c0d155a43c..5b07c2fbc0 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
@@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, db_conn, hs):
-        super(TransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TransactionStore, self).__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 62ffb34b29..90c180ec6d 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -21,6 +21,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.data_stores.main.state import StateFilter
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
@@ -37,8 +38,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -549,8 +550,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index dfeea24599..1679112d82 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
+from synapse.storage.database import Database
 
 from tests import unittest
 from tests.utils import setup_test_homeserver
@@ -382,8 +383,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
 # required for ApplicationServiceTransactionStoreTestCase tests
 class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
-    def __init__(self, db_conn, hs):
-        super(TestTransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TestTransactionStore, self).__init__(database, db_conn, hs)
 
 
 class ApplicationServiceStoreConfigTestCase(unittest.TestCase):