summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/background_updates.py11
-rw-r--r--synapse/storage/databases/main/purge_events.py9
-rw-r--r--synapse/storage/engines/_base.py13
-rw-r--r--synapse/storage/engines/postgres.py20
-rw-r--r--synapse/storage/engines/sqlite.py6
5 files changed, 45 insertions, 14 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 62fbd05534..949840e63f 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -768,8 +768,9 @@ class BackgroundUpdater:
 
                 # override the global statement timeout to avoid accidentally squashing
                 # a long-running index creation process
-                timeout_sql = "SET SESSION statement_timeout = 0"
-                c.execute(timeout_sql)
+                self.db_pool.engine.attempt_to_set_statement_timeout(
+                    c, 0, for_transaction=True
+                )
 
                 sql = (
                     "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
@@ -791,12 +792,6 @@ class BackgroundUpdater:
                     logger.debug("[SQL] %s", sql)
                     c.execute(sql)
             finally:
-                # mypy ignore - `statement_timeout` is defined on PostgresEngine
-                # reset the global timeout to the default
-                default_timeout = self.db_pool.engine.statement_timeout  # type: ignore[attr-defined]
-                undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
-                conn.cursor().execute(undo_timeout_sql)
-
                 conn.engine.attempt_to_set_autocommit(conn.conn, False)
 
         def create_index_sqlite(conn: "LoggingDatabaseConnection") -> None:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 1a5b5731bb..56c8198149 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -89,10 +89,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # furthermore, we might already have the table from a previous (failed)
         # purge attempt, so let's drop the table first.
 
-        if isinstance(self.database_engine, PostgresEngine):
-            # Disable statement timeouts for this transaction; purging rooms can
-            # take a while!
-            txn.execute("SET LOCAL statement_timeout = 0")
+        # Disable statement timeouts for this transaction; purging rooms can
+        # take a while!
+        self.database_engine.attempt_to_set_statement_timeout(
+            txn, 0, for_transaction=True
+        )
 
         txn.execute("DROP TABLE IF EXISTS events_to_purge")
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index b1a2418cbd..888b4a5660 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -36,6 +36,9 @@ CursorType = TypeVar("CursorType", bound=Cursor)
 
 
 class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCMeta):
+    # The default statement timeout to use for transactions.
+    statement_timeout: Optional[int] = None
+
     def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]):
         self.module = module
 
@@ -132,6 +135,16 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
         """
         ...
 
+    @abc.abstractmethod
+    def attempt_to_set_statement_timeout(
+        self, cursor: CursorType, statement_timeout: int, for_transaction: bool
+    ) -> None:
+        """Attempt to set the cursor's statement timeout.
+
+        Note this has no effect on SQLite3.
+        """
+        ...
+
     @staticmethod
     @abc.abstractmethod
     def executescript(cursor: CursorType, script: str) -> None:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index ec4c4041b7..6ce9ef5fcd 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -52,7 +52,7 @@ class PostgresEngine(
         # some degenerate query plan has been created and the client has probably
         # timed out/walked off anyway.
         # This is in milliseconds.
-        self.statement_timeout: Optional[int] = database_config.get(
+        self.statement_timeout = database_config.get(
             "statement_timeout", 60 * 60 * 1000
         )
         self._version: Optional[int] = None  # unknown as yet
@@ -169,7 +169,11 @@ class PostgresEngine(
 
         # Abort really long-running statements and turn them into errors.
         if self.statement_timeout is not None:
-            cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
+            self.attempt_to_set_statement_timeout(
+                cast(psycopg2.extensions.cursor, cursor.txn),
+                self.statement_timeout,
+                for_transaction=False,
+            )
 
         cursor.close()
         db_conn.commit()
@@ -233,6 +237,18 @@ class PostgresEngine(
             isolation_level = self.isolation_level_map[isolation_level]
         return conn.set_isolation_level(isolation_level)
 
+    def attempt_to_set_statement_timeout(
+        self,
+        cursor: psycopg2.extensions.cursor,
+        statement_timeout: int,
+        for_transaction: bool,
+    ) -> None:
+        if for_transaction:
+            sql = "SET LOCAL statement_timeout TO ?"
+        else:
+            sql = "SET statement_timeout TO ?"
+        cursor.execute(sql, (statement_timeout,))
+
     @staticmethod
     def executescript(cursor: psycopg2.extensions.cursor, script: str) -> None:
         """Execute a chunk of SQL containing multiple semicolon-delimited statements.
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 802069e1e1..64d2a72ed5 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -143,6 +143,12 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]):
         # All transactions are SERIALIZABLE by default in sqlite
         pass
 
+    def attempt_to_set_statement_timeout(
+        self, cursor: sqlite3.Cursor, statement_timeout: int, for_transaction: bool
+    ) -> None:
+        # Not supported.
+        pass
+
     @staticmethod
     def executescript(cursor: sqlite3.Cursor, script: str) -> None:
         """Execute a chunk of SQL containing multiple semicolon-delimited statements.