summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-07-16 10:56:49 +0100
committerRichard van der Hoff <richard@matrix.org>2020-07-16 11:25:08 +0100
commit42509b8fb614173e4ef51e12e48178c89f61e662 (patch)
tree8854b1efd61c33fd14cc59d82e3f6ebf31f4176b
parentAdd some helper classes for generating ID sequences (diff)
downloadsynapse-42509b8fb614173e4ef51e12e48178c89f61e662.tar.xz
Use `PostgresSequenceGenerator` from `MultiWriterIdGenerator`
partly just to show it works, but alwo to remove a bit of code duplication.
-rw-r--r--synapse/storage/util/id_generators.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f89ce0bed2..787cebfbec 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
 from typing_extensions import Deque
 
 from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 
 class IdGenerator(object):
@@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
     ):
         self._db = db
         self._instance_name = instance_name
-        self._sequence_name = sequence_name
 
         # We lock as some functions may be called from DB threads.
         self._lock = threading.Lock()
@@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ) -> Dict[str, int]:
@@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
         return current_positions
 
     def _load_next_id_txn(self, txn):
-        txn.execute("SELECT nextval(?)", (self._sequence_name,))
-        (next_id,) = txn.fetchone()
-        return next_id
+        return self._sequence_gen.get_next_id_txn(txn)
 
     async def get_next(self):
         """