summary refs log tree commit diff
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2021-12-13 17:05:00 +0000
committerGitHub <noreply@github.com>2021-12-13 17:05:00 +0000
commit5305a5e88144828419249fd9e4c5198d92276a44 (patch)
tree6c03eceaef4ae259d52510d64d1b9e90018483d1
parentAdd type hints to `synapse/storage/databases/main/end_to_end_keys.py` (#11551) (diff)
downloadsynapse-5305a5e88144828419249fd9e4c5198d92276a44.tar.xz
Type hint the constructors of the data store classes (#11555)
-rw-r--r--changelog.d/11555.misc1
-rw-r--r--synapse/replication/slave/storage/_base.py9
-rw-r--r--synapse/replication/slave/storage/client_ips.py9
-rw-r--r--synapse/replication/slave/storage/devices.py9
-rw-r--r--synapse/replication/slave/storage/events.py9
-rw-r--r--synapse/replication/slave/storage/filtering.py9
-rw-r--r--synapse/replication/slave/storage/groups.py9
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/database.py2
-rw-r--r--synapse/storage/databases/main/__init__.py9
-rw-r--r--synapse/storage/databases/main/appservice.py10
-rw-r--r--synapse/storage/databases/main/cache.py9
-rw-r--r--synapse/storage/databases/main/censor_events.py13
-rw-r--r--synapse/storage/databases/main/client_ips.py22
-rw-r--r--synapse/storage/databases/main/deviceinbox.py7
-rw-r--r--synapse/storage/databases/main/devices.py22
-rw-r--r--synapse/storage/databases/main/event_federation.py20
-rw-r--r--synapse/storage/databases/main/event_push_actions.py20
-rw-r--r--synapse/storage/databases/main/events.py9
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py8
-rw-r--r--synapse/storage/databases/main/group_server.py10
-rw-r--r--synapse/storage/databases/main/lock.py14
-rw-r--r--synapse/storage/databases/main/metrics.py9
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py20
-rw-r--r--synapse/storage/databases/main/presence.py6
-rw-r--r--synapse/storage/databases/main/push_rule.py9
-rw-r--r--synapse/storage/databases/main/receipts.py13
-rw-r--r--synapse/storage/databases/main/room.py27
-rw-r--r--synapse/storage/databases/main/roommember.py23
-rw-r--r--synapse/storage/databases/main/search.py20
-rw-r--r--synapse/storage/databases/main/state.py27
-rw-r--r--synapse/storage/databases/main/stats.py9
-rw-r--r--synapse/storage/databases/main/stream.py8
-rw-r--r--synapse/storage/databases/main/transactions.py13
-rw-r--r--synapse/storage/databases/main/user_directory.py11
35 files changed, 351 insertions, 87 deletions
diff --git a/changelog.d/11555.misc b/changelog.d/11555.misc
new file mode 100644
index 0000000000..d451940bf2
--- /dev/null
+++ b/changelog.d/11555.misc
@@ -0,0 +1 @@
+Add missing type hints to storage classes.
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 7ecb446e7c..7644146dba 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Optional
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -27,7 +27,12 @@ logger = logging.getLogger(__name__)
 
 
 class BaseSlavedStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
             self._cache_id_gen: Optional[
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 61cd7e5228..bc888ce1a8 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches.lrucache import LruCache
 
@@ -25,7 +25,12 @@ if TYPE_CHECKING:
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 0a58296089..a2aff75b70 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -27,7 +27,12 @@ if TYPE_CHECKING:
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 63ed50caa5..50e7379e83 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
@@ -58,7 +58,12 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 90284c202d..4d185e2b56 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -14,7 +14,7 @@
 
 from typing import TYPE_CHECKING
 
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
@@ -24,7 +24,12 @@ if TYPE_CHECKING:
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 497e16c69e..9d90e26375 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.group_server import GroupServerWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -26,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.hs = hs
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3056e64ff5..7967011afd 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,10 +17,8 @@ import logging
 from abc import ABCMeta
 from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
 
-from synapse.storage.database import LoggingTransaction  # noqa: F401
-from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401; noqa: F401
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
@@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 5552dd3c5c..3b44e6469c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -175,7 +175,7 @@ class LoggingDatabaseConnection:
     def rollback(self) -> None:
         self.conn.rollback()
 
-    def __enter__(self) -> "Connection":
+    def __enter__(self) -> "LoggingDatabaseConnection":
         self.conn.__enter__()
         return self
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 065145c0d2..716b25dd34 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -18,7 +18,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -129,7 +129,12 @@ class DataStore(
     LockStore,
     SessionStore,
 ):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 4a883dc166..92c95a41d7 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -24,9 +24,8 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -58,7 +57,12 @@ def _make_exclusive_regex(
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.appservice.app_service_config_files
         )
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 36e8422fc6..0024348067 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -25,7 +25,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -41,7 +41,12 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 0f56e10220..fd3fc298b3 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -18,7 +18,11 @@ from typing import TYPE_CHECKING, Optional
 from synapse.events.utils import prune_event_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.util import json_encoder
@@ -31,7 +35,12 @@ logger = logging.getLogger(__name__)
 
 
 class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if (
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 1dc7f0ebe3..8b0c614ece 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
@@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):
 
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -394,7 +398,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -532,7 +541,12 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
 
         # (user_id, access_token, ip,) -> last_seen
         self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ab8766c75b..b410eefdc7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
     REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index eff825dd22..3932599988 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,6 +38,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -61,7 +62,12 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -953,7 +959,12 @@ class DeviceWorkerStore(SQLBaseStore):
 
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -1085,7 +1096,12 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9580a40785..2287f1cc68 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -24,7 +24,11 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -62,7 +66,12 @@ class _NoChainCoverIndex(Exception):
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
@@ -1514,7 +1523,12 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3efdd0c920..eacff3e432 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,11 @@ from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -82,7 +86,12 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
@@ -910,7 +919,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5184e6bf85..81e67ece55 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -41,10 +41,13 @@ from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
@@ -95,7 +98,7 @@ class PersistEventsStore:
         hs: "HomeServer",
         db: DatabasePool,
         main_data_store: "DataStore",
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
     ):
         self.hs = hs
         self.db_pool = db
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index c88fd35e7f..9b36941fec 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -23,6 +23,7 @@ from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
@@ -83,7 +84,12 @@ class _CalculateChainCover:
 
 
 class EventsBackgroundUpdatesStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index bb621df0dd..3f6086050b 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -19,8 +19,7 @@ from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -40,7 +39,12 @@ class _RoomInGroup(TypedDict):
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         database.updates.register_background_index_update(
             update_name="local_group_updates_index",
             index_name="local_group_updates_stream_id_index",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index a540f7fb26..bedacaf0d7 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -20,8 +20,11 @@ from twisted.internet.interfaces import IReactorCore
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.types import Connection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
@@ -54,7 +57,12 @@ class LockStore(SQLBaseStore):
     `last_renewed_ts` column with the current time.
     """
 
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._reactor = hs.get_reactor()
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d901933ae4..3bb21958d1 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Dict
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
@@ -55,7 +55,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     stats and prometheus metrics.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 3c98ef876f..65b7e307e1 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -16,7 +16,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    make_in_list_sql_clause,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.threepids import canonicalise_email
 
@@ -31,7 +35,12 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
@@ -213,7 +222,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
 
 class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._mau_stats_only = hs.config.server.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index cc0eebdb46..02d534ae45 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
@@ -33,7 +33,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -52,7 +52,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 3b63267395..e01c94930a 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError, StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -81,7 +81,12 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 9c5625c8bb..bf0b903af2 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -32,7 +32,11 @@ from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -47,7 +51,12 @@ logger = logging.getLogger(__name__)
 
 
 class ReceiptsWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         self._instance_name = hs.get_instance_name()
 
         if isinstance(database.engine, PostgresEngine):
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d694d852d..28c4b65bbd 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -24,7 +24,11 @@ from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.search import SearchStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, ThirdPartyInstanceID
@@ -72,7 +76,12 @@ class RoomSortOrder(Enum):
 
 
 class RoomWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1050,7 +1059,12 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
 
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -1435,7 +1449,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 6b2a8d06a6..cda80d6511 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -64,7 +64,12 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -985,7 +990,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1135,7 +1145,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 7fe233767f..f87acfb866 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -20,7 +20,11 @@ from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -105,7 +109,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if not hs.config.server.enable_search:
@@ -358,7 +367,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index fa2c3b1feb..4bc044fb16 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -22,7 +22,11 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
@@ -56,7 +60,12 @@ class _GetStateGroupDelta(
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers."""
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -349,7 +358,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -536,5 +550,10 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 5d7b59d861..9020e0976c 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -24,7 +24,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -96,7 +96,12 @@ class UserSortOrder(Enum):
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 57aab55259..9488fd5094 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -49,6 +49,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
+    LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
 )
@@ -339,7 +340,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     which can be called in the initializer.
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 1622822552..54b41513ee 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,11 @@ from canonicaljson import encode_canonical_json
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
@@ -71,7 +75,12 @@ class DestinationRetryTimings:
 
 
 class TransactionWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.run_background_tasks:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e98a45b6af..0f9b8575d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -32,11 +32,14 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection
 from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
@@ -53,7 +56,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
@@ -592,7 +595,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     def __init__(
         self,
         database: DatabasePool,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ) -> None:
         super().__init__(database, db_conn, hs)