summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8448.misc1
-rwxr-xr-xscripts/synapse_port_db2
-rw-r--r--synapse/app/_base.py5
-rw-r--r--synapse/storage/database.py89
-rw-r--r--synapse/storage/databases/__init__.py2
-rw-r--r--synapse/storage/databases/main/__init__.py1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py8
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/roommember.py13
-rw-r--r--synapse/storage/databases/main/schema/delta/20/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/25/fts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/27/ts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/31/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/31/search_update.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/event_fields.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py5
-rw-r--r--synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py7
-rw-r--r--synapse/storage/databases/main/schema/delta/57/local_current_membership.py1
-rw-r--r--synapse/storage/prepare_database.py33
-rw-r--r--synapse/storage/types.py6
-rw-r--r--synapse/storage/util/id_generators.py8
-rw-r--r--synapse/storage/util/sequence.py15
-rw-r--r--tests/storage/test_appservice.py14
-rw-r--r--tests/utils.py2
25 files changed, 152 insertions, 113 deletions
diff --git a/changelog.d/8448.misc b/changelog.d/8448.misc
new file mode 100644
index 0000000000..5ddda1803b
--- /dev/null
+++ b/changelog.d/8448.misc
@@ -0,0 +1 @@
+Add SQL logging on queries that happen during startup.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index ae2887b7d2..7e12f5440c 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -489,7 +489,7 @@ class Porter(object):
 
         hs = MockHomeserver(self.hs_config)
 
-        with make_conn(db_config, engine) as db_conn:
+        with make_conn(db_config, engine, "portdb") as db_conn:
             engine.check_database(
                 db_conn, allow_outdated_version=allow_outdated_version
             )
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8bb0b142ca..f6f7b2bf42 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -272,6 +272,11 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
         hs.get_datastore().db_pool.start_profiling()
         hs.get_pusherpool().start()
 
+        # Log when we start the shut down process.
+        hs.get_reactor().addSystemEventTrigger(
+            "before", "shutdown", logger.info, "Shutting down..."
+        )
+
         setup_sentry(hs)
         setup_sdnotify(hs)
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0d9d9b7cc0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -90,13 +91,17 @@ def make_pool(
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
         **db_config.config.get("args", {})
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +114,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "Connection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -418,12 +478,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -584,7 +642,10 @@ class DatabasePool:
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                return func(conn, *args, **kwargs)
+                db_conn = LoggingDatabaseConnection(
+                    conn, self.engine, "runWithConnection"
+                )
+                return func(db_conn, *args, **kwargs)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@@ -1621,7 +1682,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1642,9 +1703,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
             db_name = database_config.name
             engine = create_engine(database_config.config)
 
-            with make_conn(database_config, engine) as db_conn:
+            with make_conn(database_config, engine, "startup") as db_conn:
                 logger.info("[database config %r]: Checking database server", db_name)
                 engine.check_database(db_conn)
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f823d66709..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -284,7 +284,6 @@ class DataStore(
             " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
             " WHERE state != ?"
         )
-        sql = self.database_engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..80f3b4d740 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -74,11 +74,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self.stream_ordering_month_ago = None
         self.stream_ordering_day_ago = None
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-        )
+        cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
         cur.close()
 
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index b2127598ef..c66f558567 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -214,7 +214,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         self._mau_stats_only = hs.config.mau_stats_only
 
         # Do not add more reserved users than the total allowable number
-        # cur = LoggingTransaction(
         self.db_pool.new_transaction(
             db_conn,
             "initialise_mau_threepids",
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..bae1bd22d3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,12 +21,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    LoggingTransaction,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
@@ -60,10 +55,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # background update still running?
         self._current_state_events_membership_up_to_date = False
 
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
+        txn = db_conn.cursor(
+            txn_name="_check_safe_current_state_events_membership_updated"
         )
         self._check_safe_current_state_events_membership_updated_txn(txn)
         txn.close()
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row[8] = bytes(row[8]).decode("utf-8")
         row[11] = bytes(row[11]).decode("utf-8")
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_token, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_token, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_origin_server_ts", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
-                database_engine.convert_param_style(
-                    "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
-                    % (",".join("?" for _ in chunk),)
-                ),
+                "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+                % (",".join("?" for _ in chunk),),
                 [as_id] + chunk,
             )
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row = list(row)
         row[12] = token_to_stream_ordering(row[12])
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_stream_ordering, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_stream_ordering, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search_order", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_fields_sender_url", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        database_engine.convert_param_style(
-            "UPDATE remote_media_cache SET last_access_ts = ?"
-        ),
-        (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
 import logging
+from io import StringIO
 
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
         select_clause,
     )
 
-    if isinstance(database_engine, PostgresEngine):
-        cur.execute(sql)
-    else:
-        cur.executescript(sql)
+    execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
                 INNER JOIN room_memberships AS r USING (event_id)
                 WHERE type = 'm.room.member' AND state_key LIKE ?
         """
-    sql = database_engine.convert_param_style(sql)
     cur.execute(sql, ("%:" + config.server_name,))
 
     cur.execute(
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 4957e77f4c..459754feab 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import imp
 import logging
 import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
 import attr
 
 from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
 
 
 def prepare_database(
-    db_conn: Connection,
+    db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ["main", "state"],
@@ -89,7 +89,7 @@ def prepare_database(
     """
 
     try:
-        cur = db_conn.cursor()
+        cur = db_conn.cursor(txn_name="prepare_database")
 
         # sqlite does not automatically start transactions for DDL / SELECT statements,
         # so we start one before running anything. This ensures that any upgrades
@@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases):
             executescript(cur, entry.absolute_path)
 
     cur.execute(
-        database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-        ),
+        "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
         (max_current_ver, False),
     )
 
@@ -486,17 +484,13 @@ def _upgrade_existing_database(
 
             # Mark as done.
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
-                ),
+                "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
                 (v, relative_path),
             )
 
             cur.execute("DELETE FROM schema_version")
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-                ),
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
                 (v, True),
             )
 
@@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             schemas to be applied
     """
     cur.execute(
-        database_engine.convert_param_style(
-            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
-        ),
-        (modname,),
+        "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
     )
     applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
@@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
 
         # Mark as done.
         cur.execute(
-            database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
-            ),
+            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
             (modname, name),
         )
 
@@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine):
 
     if current_version:
         txn.execute(
-            database_engine.convert_param_style(
-                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
-            ),
+            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..970bb1b9da 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -61,3 +61,9 @@ class Connection(Protocol):
 
     def rollback(self, *args, **kwargs) -> None:
         ...
+
+    def __enter__(self) -> "Connection":
+        ...
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index c92cd4a6ba..51f680d05d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -54,7 +54,7 @@ def _load_current_id(db_conn, table, column, step=1):
     """
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
-    cur = db_conn.cursor()
+    cur = db_conn.cursor(txn_name="_load_current_id")
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
@@ -269,7 +269,7 @@ class MultiWriterIdGenerator:
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ):
-        cur = db_conn.cursor()
+        cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
         if self._writers:
@@ -283,15 +283,12 @@ class MultiWriterIdGenerator:
                     stream_name = ?
                     AND instance_name != ALL(?)
             """
-            sql = self._db.engine.convert_param_style(sql)
             cur.execute(sql, (self._stream_name, self._writers))
 
             sql = """
                 SELECT instance_name, stream_id FROM stream_positions
                 WHERE stream_name = ?
             """
-            sql = self._db.engine.convert_param_style(sql)
-
             cur.execute(sql, (self._stream_name,))
 
             self._current_positions = {
@@ -340,7 +337,6 @@ class MultiWriterIdGenerator:
                 "instance": instance_column,
                 "cmp": "<=" if self._positive else ">=",
             }
-            sql = self._db.engine.convert_param_style(sql)
             cur.execute(sql, (min_stream_id * self._return_factor,))
 
             self._persisted_upto_position = min_stream_id
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 2dd95e2709..ff2d038ad2 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging
 import threading
 from typing import Callable, List, Optional
 
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
     ):
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
@@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
         return [i for (i,) in txn]
 
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
     ):
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
         # First we get the current max ID from the table.
         table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 46f94914ff..c905a38930 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -58,7 +58,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         # must be done after inserts
         database = hs.get_datastores().databases[0]
         self.store = ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     def tearDown(self):
@@ -132,7 +132,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
 
         db_config = hs.config.get_single_database()
         self.store = TestTransactionStore(
-            database, make_conn(db_config, self.engine), hs
+            database, make_conn(db_config, self.engine, "test"), hs
         )
 
     def _add_service(self, url, as_token, id):
@@ -448,7 +448,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
 
         database = hs.get_datastores().databases[0]
         ApplicationServiceStore(
-            database, make_conn(database._database_config, database.engine), hs
+            database, make_conn(database._database_config, database.engine, "test"), hs
         )
 
     @defer.inlineCallbacks
@@ -467,7 +467,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
@@ -491,7 +493,9 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         with self.assertRaises(ConfigError) as cm:
             database = hs.get_datastores().databases[0]
             ApplicationServiceStore(
-                database, make_conn(database._database_config, database.engine), hs
+                database,
+                make_conn(database._database_config, database.engine, "test"),
+                hs,
             )
 
         e = cm.exception
diff --git a/tests/utils.py b/tests/utils.py
index 7a927c7f74..af563ffe0f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,6 +38,7 @@ from synapse.http.server import HttpServer
 from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -88,6 +89,7 @@ def setupdb():
             host=POSTGRES_HOST,
             password=POSTGRES_PASSWORD,
         )
+        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
         prepare_database(db_conn, db_engine, None)
         db_conn.close()