diff options
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r-- | synapse/storage/database.py | 150 |
1 files changed, 133 insertions, 17 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 79ec8f119d..0ba3a025cf 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from typing import ( overload, ) +import attr from prometheus_client import Histogram from typing_extensions import Literal @@ -90,13 +91,17 @@ def make_pool( return adbapi.ConnectionPool( db_config.config["name"], cp_reactor=reactor, - cp_openfun=engine.on_new_connection, + cp_openfun=lambda conn: engine.on_new_connection( + LoggingDatabaseConnection(conn, engine, "on_new_connection") + ), **db_config.config.get("args", {}) ) def make_conn( - db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, + default_txn_name: str, ) -> Connection: """Make a new connection to the database and return it. @@ -109,11 +114,60 @@ def make_conn( for k, v in db_config.config.get("args", {}).items() if not k.startswith("cp_") } - db_conn = engine.module.connect(**db_params) + native_db_conn = engine.module.connect(**db_params) + db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + engine.on_new_connection(db_conn) return db_conn +@attr.s(slots=True) +class LoggingDatabaseConnection: + """A wrapper around a database connection that returns `LoggingTransaction` + as its cursor class. + + This is mainly used on startup to ensure that queries get logged correctly + """ + + conn = attr.ib(type=Connection) + engine = attr.ib(type=BaseDatabaseEngine) + default_txn_name = attr.ib(type=str) + + def cursor( + self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + ) -> "LoggingTransaction": + if not txn_name: + txn_name = self.default_txn_name + + return LoggingTransaction( + self.conn.cursor(), + name=txn_name, + database_engine=self.engine, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, + ) + + def close(self) -> None: + self.conn.close() + + def commit(self) -> None: + self.conn.commit() + + def rollback(self, *args, **kwargs) -> None: + self.conn.rollback(*args, **kwargs) + + def __enter__(self) -> "Connection": + self.conn.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + return self.conn.__exit__(exc_type, exc_value, traceback) + + # Proxy through any unknown lookups to the DB conn class. + def __getattr__(self, name): + return getattr(self.conn, name) + + # The type of entry which goes on our after_callbacks and exception_callbacks lists. # # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so @@ -247,6 +301,12 @@ class LoggingTransaction: def close(self) -> None: self.txn.close() + def __enter__(self) -> "LoggingTransaction": + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + class PerformanceCounters: def __init__(self): @@ -395,7 +455,7 @@ class DatabasePool: def new_transaction( self, - conn: Connection, + conn: LoggingDatabaseConnection, desc: str, after_callbacks: List[_CallbackListEntry], exception_callbacks: List[_CallbackListEntry], @@ -403,6 +463,24 @@ class DatabasePool: *args: Any, **kwargs: Any ) -> R: + """Start a new database transaction with the given connection. + + Note: The given func may be called multiple times under certain + failure modes. This is normally fine when in a standard transaction, + but care must be taken if the connection is in `autocommit` mode that + the function will correctly handle being aborted and retried half way + through its execution. + + Args: + conn + desc + after_callbacks + exception_callbacks + func + *args + **kwargs + """ + start = monotonic_time() txn_id = self._TXN_ID @@ -418,12 +496,10 @@ class DatabasePool: i = 0 N = 5 while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.engine, - after_callbacks, - exception_callbacks, + cursor = conn.cursor( + txn_name=name, + after_callbacks=after_callbacks, + exception_callbacks=exception_callbacks, ) try: r = func(cursor, *args, **kwargs) @@ -508,7 +584,12 @@ class DatabasePool: sql_txn_timer.labels(desc).observe(duration) async def runInteraction( - self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + desc: str, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Starts a transaction on the database and runs a given function @@ -518,6 +599,18 @@ class DatabasePool: database transaction (twisted.enterprise.adbapi.Transaction) as its first argument, followed by `args` and `kwargs`. + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transactions + that are only a single query. + + Currently, this is only implemented for Postgres. SQLite will still + run the function inside a transaction. + + WARNING: This means that if func fails half way through then + the changes will *not* be rolled back. `func` may also get + called multiple times if the transaction is retried, so must + correctly handle that case. + args: positional args to pass to `func` kwargs: named args to pass to `func` @@ -538,6 +631,7 @@ class DatabasePool: exception_callbacks, func, *args, + db_autocommit=db_autocommit, **kwargs ) @@ -551,7 +645,11 @@ class DatabasePool: return cast(R, result) async def runWithConnection( - self, func: "Callable[..., R]", *args: Any, **kwargs: Any + self, + func: "Callable[..., R]", + *args: Any, + db_autocommit: bool = False, + **kwargs: Any ) -> R: """Wraps the .runWithConnection() method on the underlying db_pool. @@ -560,6 +658,9 @@ class DatabasePool: database connection (twisted.enterprise.adbapi.Connection) as its first argument, followed by `args` and `kwargs`. args: positional args to pass to `func` + db_autocommit: Whether to run the function in "autocommit" mode, + i.e. outside of a transaction. This is useful for transaction + that are only a single query. Currently only affects postgres. kwargs: named args to pass to `func` Returns: @@ -575,6 +676,13 @@ class DatabasePool: start_time = monotonic_time() def inner_func(conn, *args, **kwargs): + # We shouldn't be in a transaction. If we are then something + # somewhere hasn't committed after doing work. (This is likely only + # possible during startup, as `run*` will ensure changes are + # committed/rolled back before putting the connection back in the + # pool). + assert not self.engine.in_transaction(conn) + with LoggingContext("runWithConnection", parent_context) as context: sched_duration_sec = monotonic_time() - start_time sql_scheduling_timer.observe(sched_duration_sec) @@ -584,7 +692,17 @@ class DatabasePool: logger.debug("Reconnecting closed database connection") conn.reconnect() - return func(conn, *args, **kwargs) + try: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, True) + + db_conn = LoggingDatabaseConnection( + conn, self.engine, "runWithConnection" + ) + return func(db_conn, *args, **kwargs) + finally: + if db_autocommit: + self.engine.attempt_to_set_autocommit(conn, False) return await make_deferred_yieldable( self._db_pool.runWithConnection(inner_func, *args, **kwargs) @@ -1621,7 +1739,7 @@ class DatabasePool: def get_cache_dict( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, entity_column: str, stream_column: str, @@ -1642,9 +1760,7 @@ class DatabasePool: "limit": limit, } - sql = self.engine.convert_param_style(sql) - - txn = db_conn.cursor() + txn = db_conn.cursor(txn_name="get_cache_dict") txn.execute(sql, (int(max_value),)) cache = {row[0]: int(row[1]) for row in txn} |