summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-28 18:00:30 +0100
committerGitHub <noreply@github.com>2020-09-28 18:00:30 +0100
commitbd380d942fdf91cf1214d6859f2bc97d12a92ab4 (patch)
tree515186a89d274f7d4272f4cdcb0e9698dac7e2ef /tests
parentCreate a mechanism for marking tests "logcontext clean" (#8399) (diff)
downloadsynapse-bd380d942fdf91cf1214d6859f2bc97d12a92ab4.tar.xz
Add checks for postgres sequence consistency (#8402)
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_id_generators.py22
-rw-r--r--tests/unittest.py31
2 files changed, 49 insertions, 4 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index d4ff55fbff..4558bee7be 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.unittest import HomeserverTestCase
@@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 writers=writers,
             )
 
-        return self.get_success(self.db_pool.runWithConnection(_create))
+        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
 
     def _insert_rows(self, instance_name: str, number: int):
         """Insert N rows as the given instance, inserting with stream IDs pulled
@@ -411,6 +410,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.get_success(_get_next_async())
         self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
 
+    def test_sequence_consistency(self):
+        """Test that we error out if the table and sequence diverges.
+        """
+
+        # Prefill with some rows
+        self._insert_row_with_id("master", 3)
+
+        # Now we add a row *without* updating the stream ID
+        def _insert(txn):
+            txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+        self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+        # Creating the ID gen should error
+        with self.assertRaises(IncorrectDatabaseSetup):
+            self._create_id_generator("first")
+
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
diff --git a/tests/unittest.py b/tests/unittest.py
index bbe50c3851..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import gc
 import hashlib
 import hmac
@@ -28,6 +27,7 @@ from mock import Mock, patch
 from canonicaljson import json
 
 from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -476,6 +476,35 @@ class HomeserverTestCase(TestCase):
         self.pump()
         return self.failureResultOf(d, exc)
 
+    def get_success_or_raise(self, d, by=0.0):
+        """Drive deferred to completion and return result or raise exception
+        on failure.
+        """
+
+        if inspect.isawaitable(d):
+            deferred = ensureDeferred(d)
+        if not isinstance(deferred, Deferred):
+            return d
+
+        results = []  # type: list
+        deferred.addBoth(results.append)
+
+        self.pump(by=by)
+
+        if not results:
+            self.fail(
+                "Success result expected on {!r}, found no result instead".format(
+                    deferred
+                )
+            )
+
+        result = results[0]
+
+        if isinstance(result, Failure):
+            result.raiseException()
+
+        return result
+
     def register_user(self, username, password, admin=False):
         """
         Register a user. Requires the Admin API be registered.