path: root/synapse/storage/database.py
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--  synapse/storage/database.py  150
1 file changed, 133 insertions, 17 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -90,13 +91,17 @@ def make_pool(
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
         **db_config.config.get("args", {})
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +114,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly.
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "LoggingDatabaseConnection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
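
For illustration, a minimal sketch of how the `LoggingDatabaseConnection` wrapper added above is used: a raw DB-API connection is wrapped so that every cursor taken from it is a `LoggingTransaction` and queries are logged under a named transaction. The `engine`/`db_params` objects and the "startup" name are assumptions for the sketch, not part of this patch.

    # Hypothetical usage, mirroring what make_conn() now does.
    native_db_conn = engine.module.connect(**db_params)  # raw DB-API connection
    db_conn = LoggingDatabaseConnection(native_db_conn, engine, "startup")

    # Cursors from the wrapper are LoggingTransactions, so this query is
    # logged under the default transaction name "startup".
    txn = db_conn.cursor()
    txn.execute("SELECT 1")
    txn.close()
    db_conn.commit()
    db_conn.close()
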
@@ -247,6 +301,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
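
The `__enter__`/`__exit__` methods added above let a `LoggingTransaction` be used as a context manager that closes its cursor on exit. A hedged sketch of the pattern this enables (`db_conn` is assumed to be a `LoggingDatabaseConnection` from elsewhere in this module):

    # The cursor is closed even if execute() raises; exceptions still
    # propagate because __exit__ does not swallow them.
    with db_conn.cursor(txn_name="count_users") as txn:
        txn.execute("SELECT count(*) FROM users")
        (count,) = txn.fetchone()
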
@@ -395,7 +455,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -403,6 +463,24 @@ class DatabasePool:
         *args: Any,
         **kwargs: Any
     ) -> R:
+        """Start a new database transaction with the given connection.
+
+        Note: The given func may be called multiple times under certain
+        failure modes. This is normally fine when running in a standard
+        transaction, but if the connection is in `autocommit` mode then care
+        must be taken to ensure that the function correctly handles being
+        aborted and retried halfway through its execution.
+
+        Args:
+            conn: The database connection to run the transaction on.
+            desc: A description of the transaction, for logging and metrics.
+            after_callbacks: A list that callbacks will be appended to, to be
+                run once the transaction has committed successfully.
+            exception_callbacks: A list that callbacks will be appended to, to
+                be run if the transaction fails.
+            func: The function to call with the cursor, followed by `*args`
+                and `**kwargs`.
+            *args: Positional args to pass to `func`.
+            **kwargs: Named args to pass to `func`.
+        """
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -418,12 +496,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -508,7 +584,12 @@ class DatabasePool:
             sql_txn_timer.labels(desc).observe(duration)
 
     async def runInteraction(
-        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        desc: str,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Starts a transaction on the database and runs a given function
 
@@ -518,6 +599,18 @@ class DatabasePool:
                 database transaction (twisted.enterprise.adbapi.Transaction) as
                 its first argument, followed by `args` and `kwargs`.
 
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transactions
+                that are only a single query.
+
+                Currently, this is only implemented for Postgres. SQLite will still
+                run the function inside a transaction.
+
+                WARNING: This means that if `func` fails halfway through then
+                the changes will *not* be rolled back. `func` may also get
+                called multiple times if the transaction is retried, so it
+                must correctly handle that case.
+
             args: positional args to pass to `func`
             kwargs: named args to pass to `func`
 
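
A hedged sketch of how a caller might opt in to the new `db_autocommit` flag for a single-query interaction (the store method, table, and the `self.db_pool` attribute are illustrative, not part of this patch):

    async def count_users(self):
        def count_users_txn(txn):
            txn.execute("SELECT count(*) FROM users")
            return txn.fetchone()[0]

        # A single read-only statement, so it can safely skip the surrounding
        # transaction on Postgres; SQLite ignores the flag and still uses a
        # transaction.
        return await self.db_pool.runInteraction(
            "count_users", count_users_txn, db_autocommit=True
        )
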
@@ -538,6 +631,7 @@ class DatabasePool:
                 exception_callbacks,
                 func,
                 *args,
+                db_autocommit=db_autocommit,
                 **kwargs
             )
 
@@ -551,7 +645,11 @@ class DatabasePool:
         return cast(R, result)
 
     async def runWithConnection(
-        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
@@ -560,6 +658,9 @@ class DatabasePool:
                 database connection (twisted.enterprise.adbapi.Connection) as
                 its first argument, followed by `args` and `kwargs`.
             args: positional args to pass to `func`
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transactions
+                that are only a single query. Currently this only affects
+                Postgres.
             kwargs: named args to pass to `func`
 
         Returns:
@@ -575,6 +676,13 @@ class DatabasePool:
         start_time = monotonic_time()
 
         def inner_func(conn, *args, **kwargs):
+            # We shouldn't be in a transaction. If we are then something
+            # somewhere hasn't committed after doing work. (This is likely only
+            # possible during startup, as `run*` will ensure changes are
+            # committed/rolled back before putting the connection back in the
+            # pool).
+            assert not self.engine.in_transaction(conn)
+
             with LoggingContext("runWithConnection", parent_context) as context:
                 sched_duration_sec = monotonic_time() - start_time
                 sql_scheduling_timer.observe(sched_duration_sec)
@@ -584,7 +692,17 @@ class DatabasePool:
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                return func(conn, *args, **kwargs)
+                try:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, True)
+
+                    db_conn = LoggingDatabaseConnection(
+                        conn, self.engine, "runWithConnection"
+                    )
+                    return func(db_conn, *args, **kwargs)
+                finally:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, False)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
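
For context, a rough sketch of the engine hook invoked above; the real `attempt_to_set_autocommit` implementations live in the engine classes and are not part of this diff, so the bodies below are assumptions (psycopg2's `set_session` for Postgres, a no-op for SQLite):

    class PostgresEngine(BaseDatabaseEngine):
        def attempt_to_set_autocommit(self, conn, autocommit: bool) -> None:
            # psycopg2 connections expose set_session(); with autocommit on,
            # each statement commits immediately instead of opening a
            # transaction.
            conn.set_session(autocommit=autocommit)

    class Sqlite3Engine(BaseDatabaseEngine):
        def attempt_to_set_autocommit(self, conn, autocommit: bool) -> None:
            # Not supported for SQLite: callers still run inside a transaction.
            pass
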
@@ -1621,7 +1739,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1642,9 +1760,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}