summary refs log tree commit diff
path: root/synapse/storage/util/sequence.py
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-03-01 10:06:09 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-03-01 10:06:09 +0000
commitb26bee9faf957643cd34c4146b250b0009be205d (patch)
treea7a7e29f30acb437d010bdf6116c0f2729f21a1b /synapse/storage/util/sequence.py
parentMerge remote-tracking branch 'origin/release-v1.26.0' into toml/keycloak_hints (diff)
parentFixup changelog (diff)
downloadsynapse-b26bee9faf957643cd34c4146b250b0009be205d.tar.xz
Merge remote-tracking branch 'origin/release-v1.28.0' into toml/keycloak_hints github/toml/keycloak_hints toml/keycloak_hints
Diffstat (limited to 'synapse/storage/util/sequence.py')
-rw-r--r--synapse/storage/util/sequence.py27
1 files changed, 23 insertions, 4 deletions
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..3ea637b281 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
+    @abc.abstractmethod
     def check_consistency(
         self,
         db_conn: "LoggingDatabaseConnection",
@@ -101,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def get_next_id_txn(self, txn: Cursor) -> int:
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        return txn.fetchone()[0]
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        return fetch_res[0]
 
     def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
         txn.execute(
@@ -117,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         stream_name: Optional[str] = None,
         positive: bool = True,
     ):
-        """See SequenceGenerator.check_consistency for docstring.
-        """
+        """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
@@ -142,7 +148,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute(
             "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
         )
-        last_value, is_called = txn.fetchone()
+        fetch_res = txn.fetchone()
+        assert fetch_res is not None
+        last_value, is_called = fetch_res
 
         # If we have an associated stream check the stream_positions table.
         max_in_stream_positions = None
@@ -219,6 +227,17 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
     def check_consistency(
         self,
         db_conn: Connection,