summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py197
1 files changed, 172 insertions, 25 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6116191b16..d1b5760c2c 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -87,16 +88,25 @@ def make_pool(
     """Get the connection pool for the database.
     """
 
+    # By default enable `cp_reconnect`. We need to fiddle with db_args in case
+    # someone has explicitly set `cp_reconnect`.
+    db_args = dict(db_config.config.get("args", {}))
+    db_args.setdefault("cp_reconnect", True)
+
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
-        **db_config.config.get("args", {})
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
+        **db_args,
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +119,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "Connection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +306,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
@@ -395,7 +460,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -436,12 +501,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -574,7 +637,7 @@ class DatabasePool:
                 func,
                 *args,
                 db_autocommit=db_autocommit,
-                **kwargs
+                **kwargs,
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -638,7 +701,10 @@ class DatabasePool:
                     if db_autocommit:
                         self.engine.attempt_to_set_autocommit(conn, True)
 
-                    return func(conn, *args, **kwargs)
+                    db_conn = LoggingDatabaseConnection(
+                        conn, self.engine, "runWithConnection"
+                    )
+                    return func(db_conn, *args, **kwargs)
                 finally:
                     if db_autocommit:
                         self.engine.attempt_to_set_autocommit(conn, False)
@@ -832,6 +898,12 @@ class DatabasePool:
         attempts = 0
         while True:
             try:
+                # We can autocommit if we are going to use native upserts
+                autocommit = (
+                    self.engine.can_native_upsert
+                    and table not in self._unsafe_to_upsert_tables
+                )
+
                 return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
@@ -840,6 +912,7 @@ class DatabasePool:
                     values,
                     insertion_values,
                     lock=lock,
+                    db_autocommit=autocommit,
                 )
             except self.engine.module.IntegrityError as e:
                 attempts += 1
@@ -1002,6 +1075,43 @@ class DatabasePool:
         )
         txn.execute(sql, list(allvalues.values()))
 
+    async def simple_upsert_many(
+        self,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+        desc: str,
+    ) -> None:
+        """
+        Upsert, many times.
+
+        Args:
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
+        """
+
+        # We can autocommit if we are going to use native upserts
+        autocommit = (
+            self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
+        )
+
+        return await self.runInteraction(
+            desc,
+            self.simple_upsert_many_txn,
+            table,
+            key_names,
+            key_values,
+            value_names,
+            value_values,
+            db_autocommit=autocommit,
+        )
+
     def simple_upsert_many_txn(
         self,
         txn: LoggingTransaction,
@@ -1153,7 +1263,13 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
         """
         return await self.runInteraction(
-            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+            desc,
+            self.simple_select_one_txn,
+            table,
+            keyvalues,
+            retcols,
+            allow_none,
+            db_autocommit=True,
         )
 
     @overload
@@ -1204,6 +1320,7 @@ class DatabasePool:
             keyvalues,
             retcol,
             allow_none=allow_none,
+            db_autocommit=True,
         )
 
     @overload
@@ -1285,7 +1402,12 @@ class DatabasePool:
             Results in a list
         """
         return await self.runInteraction(
-            desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+            desc,
+            self.simple_select_onecol_txn,
+            table,
+            keyvalues,
+            retcol,
+            db_autocommit=True,
         )
 
     async def simple_select_list(
@@ -1310,7 +1432,12 @@ class DatabasePool:
             A list of dictionaries.
         """
         return await self.runInteraction(
-            desc, self.simple_select_list_txn, table, keyvalues, retcols
+            desc,
+            self.simple_select_list_txn,
+            table,
+            keyvalues,
+            retcols,
+            db_autocommit=True,
         )
 
     @classmethod
@@ -1389,6 +1516,7 @@ class DatabasePool:
                 chunk,
                 keyvalues,
                 retcols,
+                db_autocommit=True,
             )
 
             results.extend(rows)
@@ -1487,7 +1615,12 @@ class DatabasePool:
             desc: description of the transaction, for logging and metrics
         """
         await self.runInteraction(
-            desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+            desc,
+            self.simple_update_one_txn,
+            table,
+            keyvalues,
+            updatevalues,
+            db_autocommit=True,
         )
 
     @classmethod
@@ -1546,7 +1679,9 @@ class DatabasePool:
             keyvalues: dict of column names and values to select the row with
             desc: description of the transaction, for logging and metrics
         """
-        await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+        await self.runInteraction(
+            desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
+        )
 
     @staticmethod
     def simple_delete_one_txn(
@@ -1585,7 +1720,9 @@ class DatabasePool:
         Returns:
             The number of deleted rows.
         """
-        return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+        return await self.runInteraction(
+            desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
+        )
 
     @staticmethod
     def simple_delete_txn(
@@ -1633,7 +1770,13 @@ class DatabasePool:
             Number rows deleted
         """
         return await self.runInteraction(
-            desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+            desc,
+            self.simple_delete_many_txn,
+            table,
+            column,
+            iterable,
+            keyvalues,
+            db_autocommit=True,
         )
 
     @staticmethod
@@ -1678,7 +1821,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1699,9 +1842,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}
@@ -1801,7 +1942,13 @@ class DatabasePool:
         """
 
         return await self.runInteraction(
-            desc, self.simple_search_list_txn, table, term, col, retcols
+            desc,
+            self.simple_search_list_txn,
+            table,
+            term,
+            col,
+            retcols,
+            db_autocommit=True,
         )
 
     @classmethod