summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/__init__.py28
-rw-r--r--synapse/storage/databases/main/__init__.py7
-rw-r--r--synapse/storage/databases/main/account_data.py7
-rw-r--r--synapse/storage/databases/main/cache.py7
-rw-r--r--synapse/storage/databases/main/deviceinbox.py9
-rw-r--r--synapse/storage/databases/main/devices.py21
-rw-r--r--synapse/storage/databases/main/event_federation.py9
-rw-r--r--synapse/storage/databases/main/event_push_actions.py9
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py7
-rw-r--r--synapse/storage/databases/main/media_repository.py9
-rw-r--r--synapse/storage/databases/main/metrics.py7
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py9
-rw-r--r--synapse/storage/databases/main/push_rule.py7
-rw-r--r--synapse/storage/databases/main/receipts.py7
-rw-r--r--synapse/storage/databases/main/room.py11
-rw-r--r--synapse/storage/databases/main/roommember.py7
-rw-r--r--synapse/storage/databases/main/search.py9
-rw-r--r--synapse/storage/databases/main/state.py11
-rw-r--r--synapse/storage/databases/main/stats.py7
-rw-r--r--synapse/storage/databases/main/transactions.py7
20 files changed, 138 insertions, 57 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 20b755056b..cfe887b7f7 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -13,33 +13,49 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
 
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_conn
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.databases.state import StateGroupDataStore
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-class Databases:
+DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
+
+
+class Databases(Generic[DataStoreT]):
     """The various databases.
 
     These are low level interfaces to physical databases.
 
     Attributes:
-        main (DataStore)
+        databases
+        main
+        state
+        persist_events
     """
 
-    def __init__(self, main_store_class, hs):
+    databases: List[DatabasePool]
+    main: DataStoreT
+    state: StateGroupDataStore
+    persist_events: Optional[PersistEventsStore]
+
+    def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
         # Note we pass in the main store class here as workers use a different main
         # store.
 
         self.databases = []
-        main = None
-        state = None
-        persist_events = None
+        main: Optional[DataStoreT] = None
+        state: Optional[StateGroupDataStore] = None
+        persist_events: Optional[PersistEventsStore] = None
 
         for database_config in hs.config.database.databases:
             db_name = database_config.name
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5c21402dea..259cae5b37 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import DatabasePool
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
 from .user_directory import UserDirectoryStore
 from .user_erasure_store import UserErasureStore
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -126,7 +129,7 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 70ca3e09f7..f8bec266ac 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -28,6 +28,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c57ae5ef15..36e8422fc6 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 3154906d45..8143168107 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -26,11 +26,14 @@ from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 6464520386..a01bf2c5b7 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,17 @@
 # limitations under the License.
 import abc
 import logging
-from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ba9f71a230..ef5d1ef01e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,7 +14,7 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
 from prometheus_client import Counter, Gauge
 
@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 oldest_pdu_in_federation_staging = Gauge(
     "synapse_federation_server_oldest_inbound_pdu_in_staging",
     "The age in seconds since we received the oldest pdu in the federation staging area",
@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 97b3e92d3f..d957e770dc 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import attr
 
@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 1afc59fafb..fc49112063 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,7 +79,7 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 2fa945d171..717487be28 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,11 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
     "media_repository_drop_index_wo_method"
 )
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):
 
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self.server_name = hs.hostname
 
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dac3d14da8..d901933ae4 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,7 +14,7 @@
 import calendar
 import logging
 import time
-from typing import Dict
+from typing import TYPE_CHECKING, Dict
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Collect metrics on the number of forward extremities that exist.
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index a14ac03d4b..b5284e4f67 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,13 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index fc720f5947..fa782023d4 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import logging
-from typing import Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
@@ -33,6 +33,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 01a4281301..c99f8aebdb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -29,11 +29,14 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 835d7889cb..f879bbe7c7 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -17,7 +17,7 @@ import collections
 import logging
 from abc import abstractmethod
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
@@ -32,6 +32,9 @@ from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import MXC_REGEX
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -69,7 +72,7 @@ class RoomSortOrder(Enum):
 
 
 class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ddb162a4fc..4b288bb2e7 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
+    from synapse.server import HomeServer
     from synapse.state import _StateCacheEntry
 
 logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index c85383c975..7fe233767f 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections import namedtuple
-from typing import Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 SearchEntry = namedtuple(
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index a8e8dd4577..fa2c3b1feb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -15,7 +15,7 @@
 import collections.abc
 import logging
 from collections import namedtuple
-from typing import Iterable, Optional, Set
+from typing import TYPE_CHECKING, Iterable, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -30,6 +30,9 @@ from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -53,7 +56,7 @@ class _GetStateGroupDelta(
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index e20033bb28..5d7b59d861 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from typing_extensions import Counter
 
@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # these fields track absolutes (e.g. total number of rooms on the server)
@@ -93,7 +96,7 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 860146cd1b..d7dc1f73ac 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 db_binary_type = memoryview
 
 logger = logging.getLogger(__name__)
@@ -57,7 +60,7 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks: