diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
overload,
)
+import attr
from prometheus_client import Histogram
from typing_extensions import Literal
@@ -90,13 +91,17 @@ def make_pool(
return adbapi.ConnectionPool(
db_config.config["name"],
cp_reactor=reactor,
- cp_openfun=engine.on_new_connection,
+ cp_openfun=lambda conn: engine.on_new_connection(
+ LoggingDatabaseConnection(conn, engine, "on_new_connection")
+ ),
**db_config.config.get("args", {})
)
def make_conn(
- db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
+ default_txn_name: str,
) -> Connection:
"""Make a new connection to the database and return it.
@@ -109,11 +114,60 @@ def make_conn(
for k, v in db_config.config.get("args", {}).items()
if not k.startswith("cp_")
}
- db_conn = engine.module.connect(**db_params)
+ native_db_conn = engine.module.connect(**db_params)
+ db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
engine.on_new_connection(db_conn)
return db_conn
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+ """A wrapper around a database connection that returns `LoggingTransaction`
+ as its cursor class.
+
+ This is mainly used on startup to ensure that queries get logged correctly
+ """
+
+ conn = attr.ib(type=Connection)
+ engine = attr.ib(type=BaseDatabaseEngine)
+ default_txn_name = attr.ib(type=str)
+
+ def cursor(
+ self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ ) -> "LoggingTransaction":
+ if not txn_name:
+ txn_name = self.default_txn_name
+
+ return LoggingTransaction(
+ self.conn.cursor(),
+ name=txn_name,
+ database_engine=self.engine,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
+ )
+
+ def close(self) -> None:
+ self.conn.close()
+
+ def commit(self) -> None:
+ self.conn.commit()
+
+ def rollback(self, *args, **kwargs) -> None:
+ self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "LoggingDatabaseConnection":
+ self.conn.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
+ return self.conn.__exit__(exc_type, exc_value, traceback)
+
+ # Proxy through any unknown lookups to the DB conn class.
+ def __getattr__(self, name):
+ return getattr(self.conn, name)
+
+
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
def close(self) -> None:
self.txn.close()
+ def __enter__(self) -> "LoggingTransaction":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
class PerformanceCounters:
def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
def new_transaction(
self,
- conn: Connection,
+ conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
@@ -403,6 +463,24 @@ class DatabasePool:
*args: Any,
**kwargs: Any
) -> R:
+ """Start a new database transaction with the given connection.
+
+ Note: The given func may be called multiple times under certain
+ failure modes. This is normally fine when in a standard transaction,
+ but care must be taken if the connection is in `autocommit` mode that
+ the function will correctly handle being aborted and retried half way
+ through its execution.
+
+        Args:
+            conn: The database connection to run the transaction on.
+            desc: A human-readable description of the transaction, used for
+                naming and logging.
+            after_callbacks: List to append callbacks to that should run after
+                the transaction completes successfully.
+            exception_callbacks: List to append callbacks to that should run
+                if the transaction fails.
+            func: The function to call with the cursor; `*args` and `**kwargs`
+                are passed through to it.
+        """
+
start = monotonic_time()
txn_id = self._TXN_ID
@@ -418,12 +496,10 @@ class DatabasePool:
i = 0
N = 5
while True:
- cursor = LoggingTransaction(
- conn.cursor(),
- name,
- self.engine,
- after_callbacks,
- exception_callbacks,
+ cursor = conn.cursor(
+ txn_name=name,
+ after_callbacks=after_callbacks,
+ exception_callbacks=exception_callbacks,
)
try:
r = func(cursor, *args, **kwargs)
@@ -508,7 +584,12 @@ class DatabasePool:
sql_txn_timer.labels(desc).observe(duration)
async def runInteraction(
- self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ self,
+ desc: str,
+ func: "Callable[..., R]",
+ *args: Any,
+ db_autocommit: bool = False,
+ **kwargs: Any
) -> R:
"""Starts a transaction on the database and runs a given function
@@ -518,6 +599,18 @@ class DatabasePool:
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
+ db_autocommit: Whether to run the function in "autocommit" mode,
+ i.e. outside of a transaction. This is useful for transactions
+ that are only a single query.
+
+ Currently, this is only implemented for Postgres. SQLite will still
+ run the function inside a transaction.
+
+ WARNING: This means that if func fails half way through then
+ the changes will *not* be rolled back. `func` may also get
+ called multiple times if the transaction is retried, so must
+ correctly handle that case.
+
args: positional args to pass to `func`
kwargs: named args to pass to `func`
@@ -538,6 +631,7 @@ class DatabasePool:
exception_callbacks,
func,
*args,
+ db_autocommit=db_autocommit,
**kwargs
)
@@ -551,7 +645,11 @@ class DatabasePool:
return cast(R, result)
async def runWithConnection(
- self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+ self,
+ func: "Callable[..., R]",
+ *args: Any,
+ db_autocommit: bool = False,
+ **kwargs: Any
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.
@@ -560,6 +658,9 @@ class DatabasePool:
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args: positional args to pass to `func`
+ db_autocommit: Whether to run the function in "autocommit" mode,
+            i.e. outside of a transaction. This is useful for transactions
+ that are only a single query. Currently only affects postgres.
kwargs: named args to pass to `func`
Returns:
@@ -575,6 +676,13 @@ class DatabasePool:
start_time = monotonic_time()
def inner_func(conn, *args, **kwargs):
+ # We shouldn't be in a transaction. If we are then something
+ # somewhere hasn't committed after doing work. (This is likely only
+ # possible during startup, as `run*` will ensure changes are
+ # committed/rolled back before putting the connection back in the
+ # pool).
+ assert not self.engine.in_transaction(conn)
+
with LoggingContext("runWithConnection", parent_context) as context:
sched_duration_sec = monotonic_time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
@@ -584,7 +692,17 @@ class DatabasePool:
logger.debug("Reconnecting closed database connection")
conn.reconnect()
- return func(conn, *args, **kwargs)
+ try:
+ if db_autocommit:
+ self.engine.attempt_to_set_autocommit(conn, True)
+
+ db_conn = LoggingDatabaseConnection(
+ conn, self.engine, "runWithConnection"
+ )
+ return func(db_conn, *args, **kwargs)
+ finally:
+ if db_autocommit:
+ self.engine.attempt_to_set_autocommit(conn, False)
return await make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@@ -1621,7 +1739,7 @@ class DatabasePool:
def get_cache_dict(
self,
- db_conn: Connection,
+ db_conn: LoggingDatabaseConnection,
table: str,
entity_column: str,
stream_column: str,
@@ -1642,9 +1760,7 @@ class DatabasePool:
"limit": limit,
}
- sql = self.engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
+ txn = db_conn.cursor(txn_name="get_cache_dict")
txn.execute(sql, (int(max_value),))
cache = {row[0]: int(row[1]) for row in txn}
|