summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py86
-rw-r--r--synapse/storage/databases/main/room.py30
2 files changed, 73 insertions, 43 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a761ad603b..974703d13a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -40,6 +40,7 @@ from twisted.enterprise import adbapi
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
+from synapse.logging import opentracing
 from synapse.logging.context import (
     LoggingContext,
     current_context,
@@ -313,7 +314,14 @@ class LoggingTransaction:
         start = time.time()
 
         try:
-            return func(sql, *args)
+            with opentracing.start_active_span(
+                "db.query",
+                tags={
+                    opentracing.tags.DATABASE_TYPE: "sql",
+                    opentracing.tags.DATABASE_STATEMENT: sql,
+                },
+            ):
+                return func(sql, *args)
         except Exception as e:
             sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
@@ -525,9 +533,16 @@ class DatabasePool:
                     exception_callbacks=exception_callbacks,
                 )
                 try:
-                    r = func(cursor, *args, **kwargs)
-                    conn.commit()
-                    return r
+                    with opentracing.start_active_span(
+                        "db.txn",
+                        tags={
+                            opentracing.SynapseTags.DB_TXN_DESC: desc,
+                            opentracing.SynapseTags.DB_TXN_ID: name,
+                        },
+                    ):
+                        r = func(cursor, *args, **kwargs)
+                        conn.commit()
+                        return r
                 except self.engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
@@ -653,16 +668,17 @@ class DatabasePool:
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = await self.runWithConnection(
-                self.new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                db_autocommit=db_autocommit,
-                **kwargs,
-            )
+            with opentracing.start_active_span(f"db.{desc}"):
+                result = await self.runWithConnection(
+                    self.new_transaction,
+                    desc,
+                    after_callbacks,
+                    exception_callbacks,
+                    func,
+                    *args,
+                    db_autocommit=db_autocommit,
+                    **kwargs,
+                )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
                 after_callback(*after_args, **after_kwargs)
@@ -718,25 +734,29 @@ class DatabasePool:
             with LoggingContext(
                 str(curr_context), parent_context=parent_context
             ) as context:
-                sched_duration_sec = monotonic_time() - start_time
-                sql_scheduling_timer.observe(sched_duration_sec)
-                context.add_database_scheduled(sched_duration_sec)
-
-                if self.engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                try:
-                    if db_autocommit:
-                        self.engine.attempt_to_set_autocommit(conn, True)
-
-                    db_conn = LoggingDatabaseConnection(
-                        conn, self.engine, "runWithConnection"
-                    )
-                    return func(db_conn, *args, **kwargs)
-                finally:
-                    if db_autocommit:
-                        self.engine.attempt_to_set_autocommit(conn, False)
+                with opentracing.start_active_span(
+                    operation_name="db.connection",
+                ):
+                    sched_duration_sec = monotonic_time() - start_time
+                    sql_scheduling_timer.observe(sched_duration_sec)
+                    context.add_database_scheduled(sched_duration_sec)
+
+                    if self.engine.is_connection_closed(conn):
+                        logger.debug("Reconnecting closed database connection")
+                        conn.reconnect()
+                        opentracing.log_kv({"message": "reconnected"})
+
+                    try:
+                        if db_autocommit:
+                            self.engine.attempt_to_set_autocommit(conn, True)
+
+                        db_conn = LoggingDatabaseConnection(
+                            conn, self.engine, "runWithConnection"
+                        )
+                        return func(db_conn, *args, **kwargs)
+                    finally:
+                        if db_autocommit:
+                            self.engine.attempt_to_set_autocommit(conn, False)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0cf450f81d..2a96bcd314 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -764,14 +764,15 @@ class RoomWorkerStore(SQLBaseStore):
         self,
         server_name: str,
         media_id: str,
-        quarantined_by: str,
+        quarantined_by: Optional[str],
     ) -> int:
-        """quarantines a single local or remote media id
+        """quarantines or unquarantines a single local or remote media id
 
         Args:
             server_name: The name of the server that holds this media
             media_id: The ID of the media to be quarantined
             quarantined_by: The user ID that initiated the quarantine request
+                If it is `None` media will be removed from quarantine
         """
         logger.info("Quarantining media: %s/%s", server_name, media_id)
         is_local = server_name == self.config.server_name
@@ -838,9 +839,9 @@ class RoomWorkerStore(SQLBaseStore):
         txn,
         local_mxcs: List[str],
         remote_mxcs: List[Tuple[str, str]],
-        quarantined_by: str,
+        quarantined_by: Optional[str],
     ) -> int:
-        """Quarantine local and remote media items
+        """Quarantine and unquarantine local and remote media items
 
         Args:
             txn (cursor)
@@ -848,18 +849,27 @@ class RoomWorkerStore(SQLBaseStore):
             remote_mxcs: A list of (remote server, media id) tuples representing
                 remote mxc URLs
             quarantined_by: The ID of the user who initiated the quarantine request
+                If it is `None` media will be removed from quarantine
         Returns:
             The total number of media items quarantined
         """
+
         # Update all the tables to set the quarantined_by flag
-        txn.executemany(
-            """
+        sql = """
             UPDATE local_media_repository
             SET quarantined_by = ?
-            WHERE media_id = ? AND safe_from_quarantine = ?
-        """,
-            ((quarantined_by, media_id, False) for media_id in local_mxcs),
-        )
+            WHERE media_id = ?
+        """
+
+        # set quarantine
+        if quarantined_by is not None:
+            sql += "AND safe_from_quarantine = ?"
+            rows = [(quarantined_by, media_id, False) for media_id in local_mxcs]
+        # remove from quarantine
+        else:
+            rows = [(quarantined_by, media_id) for media_id in local_mxcs]
+
+        txn.executemany(sql, rows)
         # Note that a rowcount of -1 can be used to indicate no rows were affected.
         total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0