summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/replication/slave/storage/push_rule.py10
-rw-r--r--synapse/replication/tcp/streams/_base.py2
-rw-r--r--synapse/rest/client/v1/push_rule.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py36
-rw-r--r--synapse/storage/util/id_generators.py68
-rw-r--r--synapse/streams/events.py2
6 files changed, 25 insertions, 95 deletions
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..90d90833f9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_push_rules_stream_token(self):
-        return (
-            self._push_rules_stream_id_gen.get_current_token(),
-            self._stream_id_gen.get_current_token(),
-        )
-
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 1e92d52165..8c3caf30c9 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
         )
 
     def _current_token(self, instance_name: str) -> int:
-        push_rules_token, _ = self.store.get_push_rules_stream_token()
+        push_rules_token = self.store.get_max_push_rules_stream_id()
         return push_rules_token
 
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e2df638cc5..e781a3bcf4 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
         return 200, {}
 
     def notify_user(self, user_id):
-        stream_id, _ = self.store.get_push_rules_stream_token()
+        stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
     async def set_rule_attr(self, user_id, spec, val):
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c2289a9557..a585e54812 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
                 await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8276a755e5..0bf772d4d1 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,7 +16,7 @@
 import contextlib
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, Set
 
 from typing_extensions import Deque
 
@@ -167,72 +167,6 @@ class StreamIdGenerator(object):
         return self.get_current_token()
 
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
-        """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
-
-    def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
-        return self.get_current_token()
-
-
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 393e34b9fb..7ab46f42bf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -39,7 +39,7 @@ class EventSources(object):
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
-        push_rules_key, _ = self.store.get_push_rules_stream_token()
+        push_rules_key = self.store.get_max_push_rules_stream_id()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
         groups_key = self.store.get_group_stream_token()