summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/data_stores/state/store.py12
-rw-r--r--synapse/storage/engines/_base.py6
-rw-r--r--synapse/storage/engines/postgres.py6
-rw-r--r--synapse/storage/engines/sqlite.py13
4 files changed, 11 insertions, 26 deletions
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 5db9f20135..128c09a2cf 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "*stateGroupMembersCache*", 500000,
         )
 
+        def get_max_state_group_txn(txn: Cursor):
+            txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+            return txn.fetchone()[0]
+
+        self._state_group_seq_gen = build_sequence_generator(
+            self.database_engine, get_max_state_group_txn, "state_group_id_seq"
+        )
+
     @cached(max_entries=10000, iterable=True)
     def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
@@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 # AFAIK, this can never happen
                 raise Exception("current_state_ids cannot be None")
 
-            state_group = self.database_engine.get_next_state_group_id(txn)
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
             self.db.simple_insert_txn(
                 txn,
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ab0bbe4bd3..908cbc79e3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
     def lock_table(self, txn, table: str) -> None:
         ...
 
-    @abc.abstractmethod
-    def get_next_state_group_id(self, txn) -> int:
-        """Returns an int that can be used as a new state_group ID
-        """
-        ...
-
     @property
     @abc.abstractmethod
     def server_version(self) -> str:
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a31588080d..ff39281f85 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
     def lock_table(self, txn, table):
         txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
 
-    def get_next_state_group_id(self, txn):
-        """Returns an int that can be used as a new state_group ID
-        """
-        txn.execute("SELECT nextval('state_group_id_seq')")
-        return txn.fetchone()[0]
-
     @property
     def server_version(self):
         """Returns a string giving the server version. For example: '8.1.5'
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 215a949442..8a0f8c89d1 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
     def lock_table(self, txn, table):
         return
 
-    def get_next_state_group_id(self, txn):
-        """Returns an int that can be used as a new state_group ID
-        """
-        # We do application locking here since if we're using sqlite then
-        # we are a single process synapse.
-        with self._current_state_group_id_lock:
-            if self._current_state_group_id is None:
-                txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
-                self._current_state_group_id = txn.fetchone()[0]
-
-            self._current_state_group_id += 1
-            return self._current_state_group_id
-
     @property
     def server_version(self):
         """Gets a string giving the server version. For example: '3.22.0'