diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index d4ff55fbff..4558bee7be 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
from synapse.storage.database import DatabasePool
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -59,7 +58,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
writers=writers,
)
- return self.get_success(self.db_pool.runWithConnection(_create))
+ return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
"""Insert N rows as the given instance, inserting with stream IDs pulled
@@ -411,6 +410,23 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+ def test_sequence_consistency(self):
+ """Test that we error out if the table and sequence diverges.
+ """
+
+ # Prefill with some rows
+ self._insert_row_with_id("master", 3)
+
+ # Now we add a row *without* updating the stream ID
+ def _insert(txn):
+ txn.execute("INSERT INTO foobar VALUES (26, 'master')")
+
+ self.get_success(self.db_pool.runInteraction("_insert", _insert))
+
+ # Creating the ID gen should error
+ with self.assertRaises(IncorrectDatabaseSetup):
+ self._create_id_generator("first")
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
diff --git a/tests/unittest.py b/tests/unittest.py
index bbe50c3851..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -28,6 +27,7 @@ from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -476,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
|