summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8456.misc1
-rw-r--r--synapse/storage/database.py69
-rw-r--r--synapse/storage/engines/_base.py17
-rw-r--r--synapse/storage/engines/postgres.py10
-rw-r--r--synapse/storage/engines/sqlite.py10
-rw-r--r--synapse/storage/util/id_generators.py12
-rw-r--r--tests/storage/test_base.py1
7 files changed, 112 insertions, 8 deletions
diff --git a/changelog.d/8456.misc b/changelog.d/8456.misc
new file mode 100644
index 0000000000..ccd260069b
--- /dev/null
+++ b/changelog.d/8456.misc
@@ -0,0 +1 @@
+Reduce number of serialization errors of `MultiWriterIdGenerator._update_table`.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0d9d9b7cc0..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -463,6 +463,24 @@ class DatabasePool:
         *args: Any,
         **kwargs: Any
     ) -> R:
+        """Start a new database transaction with the given connection.
+
+        Note: The given func may be called multiple times under certain
+        failure modes. This is normally fine when in a standard transaction,
+        but care must be taken if the connection is in `autocommit` mode that
+        the function will correctly handle being aborted and retried half way
+        through its execution.
+
+        Args:
+            conn
+            desc
+            after_callbacks
+            exception_callbacks
+            func
+            *args
+            **kwargs
+        """
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -566,7 +584,12 @@ class DatabasePool:
             sql_txn_timer.labels(desc).observe(duration)
 
     async def runInteraction(
-        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        desc: str,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Starts a transaction on the database and runs a given function
 
@@ -576,6 +599,18 @@ class DatabasePool:
                 database transaction (twisted.enterprise.adbapi.Transaction) as
                 its first argument, followed by `args` and `kwargs`.
 
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transactions
+                that are only a single query.
+
+                Currently, this is only implemented for Postgres. SQLite will still
+                run the function inside a transaction.
+
+                WARNING: This means that if func fails half way through then
+                the changes will *not* be rolled back. `func` may also get
+                called multiple times if the transaction is retried, so must
+                correctly handle that case.
+
             args: positional args to pass to `func`
             kwargs: named args to pass to `func`
 
@@ -596,6 +631,7 @@ class DatabasePool:
                 exception_callbacks,
                 func,
                 *args,
+                db_autocommit=db_autocommit,
                 **kwargs
             )
 
@@ -609,7 +645,11 @@ class DatabasePool:
         return cast(R, result)
 
     async def runWithConnection(
-        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
@@ -618,6 +658,9 @@ class DatabasePool:
                 database connection (twisted.enterprise.adbapi.Connection) as
                 its first argument, followed by `args` and `kwargs`.
             args: positional args to pass to `func`
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transaction
+                that are only a single query. Currently only affects postgres.
             kwargs: named args to pass to `func`
 
         Returns:
@@ -633,6 +676,13 @@ class DatabasePool:
         start_time = monotonic_time()
 
         def inner_func(conn, *args, **kwargs):
+            # We shouldn't be in a transaction. If we are then something
+            # somewhere hasn't committed after doing work. (This is likely only
+            # possible during startup, as `run*` will ensure changes are
+            # committed/rolled back before putting the connection back in the
+            # pool).
+            assert not self.engine.in_transaction(conn)
+
             with LoggingContext("runWithConnection", parent_context) as context:
                 sched_duration_sec = monotonic_time() - start_time
                 sql_scheduling_timer.observe(sched_duration_sec)
@@ -642,10 +692,17 @@ class DatabasePool:
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                db_conn = LoggingDatabaseConnection(
-                    conn, self.engine, "runWithConnection"
-                )
-                return func(db_conn, *args, **kwargs)
+                try:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, True)
+
+                    db_conn = LoggingDatabaseConnection(
+                        conn, self.engine, "runWithConnection"
+                    )
+                    return func(db_conn, *args, **kwargs)
+                finally:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, False)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 908cbc79e3..d6d632dc10 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         """Gets a string giving the server version. For example: '3.22.0'
         """
         ...
+
+    @abc.abstractmethod
+    def in_transaction(self, conn: Connection) -> bool:
+        """Whether the connection is currently in a transaction.
+        """
+        ...
+
+    @abc.abstractmethod
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        """Attempt to set the connections autocommit mode.
+
+        When True queries are run outside of transactions.
+
+        Note: This has no effect on SQLite3, so callers still need to
+        commit/rollback the connections.
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index ff39281f85..7719ac32f7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,7 +15,8 @@
 
 import logging
 
-from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.types import Connection
 
 logger = logging.getLogger(__name__)
 
@@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine):
             cursor.execute("SET synchronous_commit TO OFF")
 
         cursor.close()
+        db_conn.commit()
 
     @property
     def can_native_upsert(self):
@@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine):
             return "%i.%i" % (numver / 10000, numver % 10000)
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
+
+    def in_transaction(self, conn: Connection) -> bool:
+        return conn.status != self.module.extensions.STATUS_READY  # type: ignore
+
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        return conn.set_session(autocommit=autocommit)  # type: ignore
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 8a0f8c89d1..5db0f0b520 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -17,6 +17,7 @@ import threading
 import typing
 
 from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.types import Connection
 
 if typing.TYPE_CHECKING:
     import sqlite3  # noqa: F401
@@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
 
         db_conn.create_function("rank", 1, _rank)
         db_conn.execute("PRAGMA foreign_keys = ON;")
+        db_conn.commit()
 
     def is_deadlock(self, error):
         return False
@@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         """
         return "%i.%i.%i" % self.module.sqlite_version_info
 
+    def in_transaction(self, conn: Connection) -> bool:
+        return conn.in_transaction  # type: ignore
+
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        # Twisted doesn't let us set attributes on the connections, so we can't
+        # set the connection to autocommit mode.
+        pass
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 51f680d05d..d7e40aaa8b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -24,6 +24,7 @@ from typing_extensions import Deque
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
@@ -548,7 +549,7 @@ class MultiWriterIdGenerator:
                 # do.
                 break
 
-    def _update_stream_positions_table_txn(self, txn):
+    def _update_stream_positions_table_txn(self, txn: Cursor):
         """Update the `stream_positions` table with newly persisted position.
         """
 
@@ -598,10 +599,13 @@ class _MultiWriterCtxManager:
     stream_ids = attr.ib(type=List[int], factory=list)
 
     async def __aenter__(self) -> Union[int, List[int]]:
+        # It's safe to run this in autocommit mode as fetching values from a
+        # sequence ignores transaction semantics anyway.
         self.stream_ids = await self.id_gen._db.runInteraction(
             "_load_next_mult_id",
             self.id_gen._load_next_mult_id_txn,
             self.multiple_ids or 1,
+            db_autocommit=True,
         )
 
         # Assert the fetched ID is actually greater than any ID we've already
@@ -632,10 +636,16 @@ class _MultiWriterCtxManager:
         #
         # We only do this on the success path so that the persisted current
         # position points to a persisted row with the correct instance name.
+        #
+        # We do this in autocommit mode as a) the upsert works correctly outside
+        # transactions and b) reduces the amount of time the rows are locked
+        # for. If we don't do this then we'll often hit serialization errors due
+        # to the fact we default to REPEATABLE READ isolation levels.
         if self.id_gen._writers:
             await self.id_gen._db.runInteraction(
                 "MultiWriterIdGenerator._update_table",
                 self.id_gen._update_stream_positions_table_txn,
+                db_autocommit=True,
             )
 
         return False
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 40ba652248..eac7e4dcd2 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -56,6 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         engine = create_engine(sqlite_config)
         fake_engine = Mock(wraps=engine)
         fake_engine.can_native_upsert = False
+        fake_engine.in_transaction.return_value = False
 
         db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
         db._db_pool = self.db_pool