summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/account_data.py11
-rw-r--r--synapse/storage/util/id_generators.py45
-rw-r--r--synapse/storage/util/sequence.py2
3 files changed, 47 insertions, 11 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 308d19440f..2d2ba74347 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     ):
         super().__init__(database, db_conn, hs)
 
-        # `_can_write_to_account_data` indicates whether the current worker is allowed
-        # to write account data. A value of `True` implies that `_account_data_id_gen`
-        # is an `AbstractStreamIdGenerator` and not just a tracker.
-        self._account_data_id_gen: AbstractStreamIdTracker
         self._can_write_to_account_data = (
             self._instance_name in hs.config.worker.writers.account_data
         )
 
+        self._account_data_id_gen: AbstractStreamIdGenerator
+
         if isinstance(database.engine, PostgresEngine):
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
@@ -558,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -598,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             data to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_room_txn(
             txn: LoggingTransaction, next_id: int
@@ -663,7 +658,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -770,7 +764,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_user_txn(
             txn: LoggingTransaction, next_id: int
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9adff3f4f5..334d3d718b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -158,6 +158,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Usage:
+            stream_id_gen.get_next_txn(txn)
+            # ... persist events ...
+        """
+        raise NotImplementedError()
+
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with a single writer.
@@ -263,6 +272,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
 
         return _AsyncCtxManagerWrapper(manager())
 
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Retrieve the next stream ID from within a database transaction.
+
+        Clean-up functions will be called when the transaction finishes.
+
+        Args:
+            txn: The database transaction object.
+
+        Returns:
+            The next stream ID.
+        """
+        if not self._is_writer:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
+        # Get the next stream ID.
+        with self._lock:
+            self._current += self._step
+            next_id = self._current
+
+            self._unfinished_ids[next_id] = next_id
+
+        def clear_unfinished_id(id_to_clear: int) -> None:
+            """A function to mark processing this ID as finished"""
+            with self._lock:
+                self._unfinished_ids.pop(id_to_clear)
+
+        # Mark this ID as finished once the database transaction itself finishes.
+        txn.call_after(clear_unfinished_id, next_id)
+        txn.call_on_exception(clear_unfinished_id, next_id)
+
+        # Return the new ID.
+        return next_id
+
     def get_current_token(self) -> int:
         if not self._is_writer:
             return self._current
@@ -568,7 +611,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         """
         Usage:
 
-            stream_id = stream_id_gen.get_next(txn)
+            stream_id = stream_id_gen.get_next_txn(txn)
             # ... persist event ...
         """
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 75268cbe15..80915216de 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         """
         Args:
             get_first_callback: a callback which is called on the first call to
-                 get_next_id_txn; should return the curreent maximum id
+                 get_next_id_txn; should return the current maximum id
         """
         # the callback. this is cleared after it is called, so that it can be GCed.
         self._callback: Optional[GetFirstCallbackType] = get_first_callback