diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/storage/background_updates.py | 13 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 15 | ||||
-rw-r--r-- | synapse/storage/engines/psycopg.py | 8 | ||||
-rw-r--r-- | synapse/storage/engines/psycopg2.py | 5 |
4 files changed, 26 insertions, 15 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 62fbd05534..58284223f5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -755,6 +755,8 @@ class BackgroundUpdater: # postgres insists on autocommit for the index conn.engine.attempt_to_set_autocommit(conn.conn, True) + assert isinstance(self.db_pool.engine, PostgresEngine) + try: c = conn.cursor() @@ -768,8 +770,7 @@ class BackgroundUpdater: # override the global statement timeout to avoid accidentally squashing # a long-running index creation process - timeout_sql = "SET SESSION statement_timeout = 0" - c.execute(timeout_sql) + self.db_pool.engine.set_statement_timeout(c, 0) sql = ( "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" @@ -791,11 +792,11 @@ class BackgroundUpdater: logger.debug("[SQL] %s", sql) c.execute(sql) finally: - # mypy ignore - `statement_timeout` is defined on PostgresEngine # reset the global timeout to the default - default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined] - undo_timeout_sql = f"SET statement_timeout = {default_timeout}" - conn.cursor().execute(undo_timeout_sql) + if self.db_pool.engine.statement_timeout is not None: + self.db_pool.engine.set_statement_timeout( + conn.cursor(), self.db_pool.engine.statement_timeout + ) conn.engine.attempt_to_set_autocommit(conn.conn, False) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 911abddc19..05a5330ed7 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -64,6 +64,11 @@ class PostgresEngine( """ ... + @abc.abstractmethod + def set_statement_timeout(self, cursor: CursorType, statement_timeout: int) -> None: + """Configure the current cursor's statement timeout.""" + ... + @property def single_threaded(self) -> bool: return False @@ -168,15 +173,7 @@ class PostgresEngine( # Abort really long-running statements and turn them into errors. if self.statement_timeout is not None: - # TODO Avoid a circular import, this needs to be abstracted. - if self.__class__.__name__ == "Psycopg2Engine": - cursor.execute("SET statement_timeout TO ?", (self.statement_timeout,)) - else: - cursor.execute( - sql.SQL("SET statement_timeout TO {}").format( - self.statement_timeout - ) - ) + self.set_statement_timeout(cursor.txn, self.statement_timeout) # type: ignore[arg-type] cursor.close() db_conn.commit() diff --git a/synapse/storage/engines/psycopg.py b/synapse/storage/engines/psycopg.py index 8d054ab6df..6dd01319e1 100644 --- a/synapse/storage/engines/psycopg.py +++ b/synapse/storage/engines/psycopg.py @@ -52,6 +52,14 @@ class PsycopgEngine( def get_server_version(self, db_conn: psycopg.Connection) -> int: return db_conn.info.server_version + def set_statement_timeout( + self, cursor: psycopg.Cursor, statement_timeout: int + ) -> None: + """Configure the current cursor's statement timeout.""" + cursor.execute( + psycopg.sql.SQL("SET statement_timeout TO {}").format(statement_timeout) + ) + def convert_param_style(self, sql: str) -> str: # if isinstance(sql, psycopg.sql.Composed): # return sql diff --git a/synapse/storage/engines/psycopg2.py b/synapse/storage/engines/psycopg2.py index e8af8c2c48..efb66778f9 100644 --- a/synapse/storage/engines/psycopg2.py +++ b/synapse/storage/engines/psycopg2.py @@ -51,6 +51,11 @@ class Psycopg2Engine( def get_server_version(self, db_conn: psycopg2.extensions.connection) -> int: return db_conn.server_version + def set_statement_timeout( + self, cursor: psycopg2.extensions.cursor, statement_timeout: int + ) -> None: + cursor.execute("SET statement_timeout TO ?", (statement_timeout,)) + def convert_param_style(self, sql: str) -> str: return sql.replace("?", "%s") |