summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/push_rule.py36
-rw-r--r--synapse/storage/util/id_generators.py68
2 files changed, 18 insertions, 86 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c2289a9557..a585e54812 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
                 await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8276a755e5..0bf772d4d1 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,7 +16,7 @@
 import contextlib
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, Set
 
 from typing_extensions import Deque
 
@@ -167,72 +167,6 @@ class StreamIdGenerator(object):
         return self.get_current_token()
 
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
-        """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
-
-    def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
-        return self.get_current_token()
-
-
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.