summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/background_updates.py13
-rw-r--r--synapse/storage/engines/postgres.py15
-rw-r--r--synapse/storage/engines/psycopg.py8
-rw-r--r--synapse/storage/engines/psycopg2.py5
4 files changed, 26 insertions, 15 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 62fbd05534..58284223f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -755,6 +755,8 @@ class BackgroundUpdater:
             # postgres insists on autocommit for the index
             conn.engine.attempt_to_set_autocommit(conn.conn, True)
 
+            assert isinstance(self.db_pool.engine, PostgresEngine)
+
             try:
                 c = conn.cursor()
 
@@ -768,8 +770,7 @@ class BackgroundUpdater:
 
                 # override the global statement timeout to avoid accidentally squashing
                 # a long-running index creation process
-                timeout_sql = "SET SESSION statement_timeout = 0"
-                c.execute(timeout_sql)
+                self.db_pool.engine.set_statement_timeout(c, 0)
 
                 sql = (
                     "CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
@@ -791,11 +792,11 @@ class BackgroundUpdater:
                     logger.debug("[SQL] %s", sql)
                     c.execute(sql)
             finally:
-                # mypy ignore - `statement_timeout` is defined on PostgresEngine
                 # reset the global timeout to the default
-                default_timeout = self.db_pool.engine.statement_timeout  # type: ignore[attr-defined]
-                undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
-                conn.cursor().execute(undo_timeout_sql)
+                if self.db_pool.engine.statement_timeout is not None:
+                    self.db_pool.engine.set_statement_timeout(
+                        conn.cursor(), self.db_pool.engine.statement_timeout
+                    )
 
                 conn.engine.attempt_to_set_autocommit(conn.conn, False)
 
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 911abddc19..05a5330ed7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -64,6 +64,11 @@ class PostgresEngine(
         """
         ...
 
+    @abc.abstractmethod
+    def set_statement_timeout(self, cursor: CursorType, statement_timeout: int) -> None:
+        """Configure the current cursor's statement timeout."""
+        ...
+
     @property
     def single_threaded(self) -> bool:
         return False
@@ -168,15 +173,7 @@ class PostgresEngine(
 
         # Abort really long-running statements and turn them into errors.
         if self.statement_timeout is not None:
-            # TODO Avoid a circular import, this needs to be abstracted.
-            if self.__class__.__name__ == "Psycopg2Engine":
-                cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
-            else:
-                cursor.execute(
-                    sql.SQL("SET statement_timeout TO {}").format(
-                        self.statement_timeout
-                    )
-                )
+            self.set_statement_timeout(cursor.txn, self.statement_timeout)  # type: ignore[arg-type]
 
         cursor.close()
         db_conn.commit()
diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py
index 8d054ab6df..6dd01319e1 100644
--- a/synapse/storage/engines/psycopg.py
+++ b/synapse/storage/engines/psycopg.py
@@ -52,6 +52,14 @@ class PsycopgEngine(
     def get_server_version(self, db_conn: psycopg.Connection) -> int:
         return db_conn.info.server_version
 
+    def set_statement_timeout(
+        self, cursor: psycopg.Cursor, statement_timeout: int
+    ) -> None:
+        """Configure the current cursor's statement timeout."""
+        cursor.execute(
+            psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout)
+        )
+
     def convert_param_style(self, sql: str) -> str:
         # if isinstance(sql, psycopg.sql.Composed):
         #     return sql
diff --git a/synapse/storage/engines/psycopg2.py b/synapse/storage/engines/psycopg2.py
index e8af8c2c48..efb66778f9 100644
--- a/synapse/storage/engines/psycopg2.py
+++ b/synapse/storage/engines/psycopg2.py
@@ -51,6 +51,11 @@ class Psycopg2Engine(
     def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int:
         return db_conn.server_version
 
+    def set_statement_timeout(
+        self, cursor: psycopg2.extensions.cursor, statement_timeout: int
+    ) -> None:
+        cursor.execute("SET statement_timeout TO ?", (statement_timeout,))
+
     def convert_param_style(self, sql: str) -> str:
         return sql.replace("?", "%s")