diff options
Diffstat (limited to 'synapse/storage/engines')
-rw-r--r-- | synapse/storage/engines/postgres.py | 15 | ||||
-rw-r--r-- | synapse/storage/engines/psycopg.py | 8 | ||||
-rw-r--r-- | synapse/storage/engines/psycopg2.py | 5 |
3 files changed, 19 insertions, 9 deletions
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 911abddc19..05a5330ed7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -64,6 +64,11 @@ class PostgresEngine( """ ... + @abc.abstractmethod + def set_statement_timeout(self, cursor: CursorType, statement_timeout: int) -> None: + """Configure the current cursor's statement timeout.""" + ... + @property def single_threaded(self) -> bool: return False @@ -168,15 +173,7 @@ class PostgresEngine( # Abort really long-running statements and turn them into errors. if self.statement_timeout is not None: - # TODO Avoid a circular import, this needs to be abstracted. - if self.__class__.__name__ == "Psycopg2Engine": - cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,)) - else: - cursor.execute( - sql.SQL("SET statement_timeout TO {}").format( - self.statement_timeout - ) - ) + self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type] cursor.close() db_conn.commit() diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py index 8d054ab6df..6dd01319e1 100644 --- a/synapse/storage/engines/psycopg.py +++ b/synapse/storage/engines/psycopg.py @@ -52,6 +52,14 @@ class PsycopgEngine( def get_server_version(self, db_conn: psycopg.Connection) -> int: return db_conn.info.server_version + def set_statement_timeout( + self, cursor: psycopg.Cursor, statement_timeout: int + ) -> None: + """Configure the current cursor's statement timeout.""" + cursor.execute( + psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout) + ) + def convert_param_style(self, sql: str) -> str: # if isinstance(sql, psycopg.sql.Composed): # return sql diff --git a/synapse/storage/engines/psycopg2.py b/synapse/storage/engines/psycopg2.py index e8af8c2c48..efb66778f9 100644 --- a/synapse/storage/engines/psycopg2.py +++ b/synapse/storage/engines/psycopg2.py @@ -51,6 +51,11 @@ class Psycopg2Engine( def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int: return db_conn.server_version + def set_statement_timeout( + self, cursor: psycopg2.extensions.cursor, statement_timeout: int + ) -> None: + cursor.execute("SET statement_timeout TO ?", (statement_timeout,)) + def convert_param_style(self, sql: str) -> str: return sql.replace("?", "%s") |