summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/10440.feature1
-rw-r--r--docs/sample_config.yaml4
-rw-r--r--synapse/config/database.py4
-rw-r--r--synapse/storage/database.py21
-rw-r--r--tests/storage/test_txn_limit.py36
-rw-r--r--tests/utils.py3
6 files changed, 69 insertions, 0 deletions
diff --git a/changelog.d/10440.feature b/changelog.d/10440.feature
new file mode 100644
index 0000000000..f1833b0bd7
--- /dev/null
+++ b/changelog.d/10440.feature
@@ -0,0 +1 @@
+Allow setting transaction limit for database connections.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 853c2f6899..1a217f35db 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -720,6 +720,9 @@ caches:
 # 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
 # 'psycopg2' (for PostgreSQL).
 #
+# 'txn_limit' gives the maximum number of transactions to run per connection
+# before reconnecting. Defaults to 0, which means no limit.
+#
 # 'args' gives options which are passed through to the database engine,
 # except for options starting 'cp_', which are used to configure the Twisted
 # connection pool. For a reference to valid arguments, see:
@@ -740,6 +743,7 @@ caches:
 #
 #database:
 #  name: psycopg2
+#  txn_limit: 10000
 #  args:
 #    user: synapse_user
 #    password: secretpassword
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 3d7d92f615..651e31b576 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -33,6 +33,9 @@ DEFAULT_CONFIG = """\
 # 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
 # 'psycopg2' (for PostgreSQL).
 #
+# 'txn_limit' gives the maximum number of transactions to run per connection
+# before reconnecting. Defaults to 0, which means no limit.
+#
 # 'args' gives options which are passed through to the database engine,
 # except for options starting 'cp_', which are used to configure the Twisted
 # connection pool. For a reference to valid arguments, see:
@@ -53,6 +56,7 @@ DEFAULT_CONFIG = """\
 #
 #database:
 #  name: psycopg2
+#  txn_limit: 10000
 #  args:
 #    user: synapse_user
 #    password: secretpassword
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4d4643619f..c8015a3848 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import time
+from collections import defaultdict
 from sys import intern
 from time import monotonic as monotonic_time
 from typing import (
@@ -397,6 +398,7 @@ class DatabasePool:
     ):
         self.hs = hs
         self._clock = hs.get_clock()
+        self._txn_limit = database_config.config.get("txn_limit", 0)
         self._database_config = database_config
         self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
 
@@ -406,6 +408,9 @@ class DatabasePool:
         self._current_txn_total_time = 0.0
         self._previous_loop_ts = 0.0
 
+        # Transaction counter: key is the twisted thread id, value is the current count
+        self._txn_counters: Dict[int, int] = defaultdict(int)
+
         # TODO(paul): These can eventually be removed once the metrics code
         #   is running in mainline, and we have some nice monitoring frontends
         #   to watch it
@@ -750,10 +755,26 @@ class DatabasePool:
                     sql_scheduling_timer.observe(sched_duration_sec)
                     context.add_database_scheduled(sched_duration_sec)
 
+                    if self._txn_limit > 0:
+                        tid = self._db_pool.threadID()
+                        self._txn_counters[tid] += 1
+
+                        if self._txn_counters[tid] > self._txn_limit:
+                            logger.debug(
+                                "Reconnecting database connection over transaction limit"
+                            )
+                            conn.reconnect()
+                            opentracing.log_kv(
+                                {"message": "reconnected due to txn limit"}
+                            )
+                            self._txn_counters[tid] = 1
+
                     if self.engine.is_connection_closed(conn):
                         logger.debug("Reconnecting closed database connection")
                         conn.reconnect()
                         opentracing.log_kv({"message": "reconnected"})
+                        if self._txn_limit > 0:
+                            self._txn_counters[tid] = 1
 
                     try:
                         if db_autocommit:
diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py
new file mode 100644
index 0000000000..9be51f9ebd
--- /dev/null
+++ b/tests/storage/test_txn_limit.py
@@ -0,0 +1,36 @@
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+
+class SQLTransactionLimitTestCase(unittest.HomeserverTestCase):
+    """Test SQL transaction limit doesn't break transactions."""
+
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(db_txn_limit=1000)
+
+    def test_config(self):
+        db_config = self.hs.config.get_single_database()
+        self.assertEqual(db_config.config["txn_limit"], 1000)
+
+    def test_select(self):
+        def do_select(txn):
+            txn.execute("SELECT 1")
+
+        db_pool = self.hs.get_datastores().databases[0]
+
+        # force txn limit to roll over at least once
+        for i in range(0, 1001):
+            self.get_success_or_raise(db_pool.runInteraction("test_select", do_select))
diff --git a/tests/utils.py b/tests/utils.py
index 6bd008dcfe..f3458ca88d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -239,6 +239,9 @@ def setup_test_homeserver(
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
+    if "db_txn_limit" in kwargs:
+        database_config["txn_limit"] = kwargs["db_txn_limit"]
+
     database = DatabaseConnectionConfig("master", database_config)
     config.database.databases = [database]