summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py89
1 files changed, 74 insertions, 15 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0d9d9b7cc0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -90,13 +91,17 @@ def make_pool(
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
         **db_config.config.get("args", {})
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +114,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "Connection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -418,12 +478,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -584,7 +642,10 @@ class DatabasePool:
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                return func(conn, *args, **kwargs)
+                db_conn = LoggingDatabaseConnection(
+                    conn, self.engine, "runWithConnection"
+                )
+                return func(db_conn, *args, **kwargs)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@@ -1621,7 +1682,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1642,9 +1703,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}