diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 911abddc19..05a5330ed7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -64,6 +64,11 @@ class PostgresEngine(
"""
...
+ @abc.abstractmethod
+ def set_statement_timeout(self, cursor: CursorType, statement_timeout: int) -> None:
+ """Configure the current cursor's statement timeout."""
+ ...
+
@property
def single_threaded(self) -> bool:
return False
@@ -168,15 +173,7 @@ class PostgresEngine(
# Abort really long-running statements and turn them into errors.
if self.statement_timeout is not None:
- # TODO Avoid a circular import, this needs to be abstracted.
- if self.__class__.__name__ == "Psycopg2Engine":
- cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,))
- else:
- cursor.execute(
- sql.SQL("SET statement_timeout TO {}").format(
- self.statement_timeout
- )
- )
+ self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type]
cursor.close()
db_conn.commit()
diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py
index 8d054ab6df..6dd01319e1 100644
--- a/synapse/storage/engines/psycopg.py
+++ b/synapse/storage/engines/psycopg.py
@@ -52,6 +52,14 @@ class PsycopgEngine(
def get_server_version(self, db_conn: psycopg.Connection) -> int:
return db_conn.info.server_version
+ def set_statement_timeout(
+ self, cursor: psycopg.Cursor, statement_timeout: int
+ ) -> None:
+ """Configure the current cursor's statement timeout."""
+ cursor.execute(
+ psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout)
+ )
+
def convert_param_style(self, sql: str) -> str:
# if isinstance(sql, psycopg.sql.Composed):
# return sql
diff --git a/synapse/storage/engines/psycopg2.py b/synapse/storage/engines/psycopg2.py
index e8af8c2c48..efb66778f9 100644
--- a/synapse/storage/engines/psycopg2.py
+++ b/synapse/storage/engines/psycopg2.py
@@ -51,6 +51,11 @@ class Psycopg2Engine(
def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int:
return db_conn.server_version
+ def set_statement_timeout(
+ self, cursor: psycopg2.extensions.cursor, statement_timeout: int
+ ) -> None:
+ cursor.execute("SET statement_timeout TO ?", (statement_timeout,))
+
def convert_param_style(self, sql: str) -> str:
return sql.replace("?", "%s")
|