diff --git a/changelog.d/8123.misc b/changelog.d/8123.misc
new file mode 100644
index 0000000000..7245122896
--- /dev/null
+++ b/changelog.d/8123.misc
@@ -0,0 +1 @@
+Remove `ChainedIdGenerator`.
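To make the migration concrete, here is a minimal sketch (not part of the patch; `store` stands in for `PushRuleStore`) of how call sites change now that the pair-yielding generator is gone:

    # Before: ChainedIdGenerator yielded a (stream_id, event_stream_ordering) pair.
    with store._push_rules_stream_id_gen.get_next() as ids:
        stream_id, event_stream_ordering = ids

    # After: StreamIdGenerator yields a single stream id; the event stream
    # position is snapshotted separately from the event stream's generator.
    with store._push_rules_stream_id_gen.get_next() as stream_id:
        event_stream_ordering = store._stream_id_gen.get_current_token()
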
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 590187df46..90d90833f9 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_push_rules_stream_token(self):
-        return (
-            self._push_rules_stream_id_gen.get_current_token(),
-            self._stream_id_gen.get_current_token(),
-        )
-
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
+        # We assert this for the benefit of mypy
+        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
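On worker processes `_push_rules_stream_id_gen` is always a `SlavedIdTracker`, but its annotation is a `Union`, so mypy cannot see that `advance` is safe to call; the assert above narrows the type. A self-contained sketch of the pattern, using hypothetical `Writer`/`Tracker` classes rather than the real Synapse ones:

    from typing import Union

    class Writer:
        def get_current_token(self) -> int:
            return 0

    class Tracker:
        def get_current_token(self) -> int:
            return 0

        def advance(self, token: int) -> None:
            ...  # followers move their local position forward

    def on_row(gen: Union[Writer, Tracker], token: int) -> None:
        # mypy only accepts gen.advance(...) once the union is narrowed.
        assert isinstance(gen, Tracker)
        gen.advance(token)
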
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 1e92d52165..8c3caf30c9 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
         )
 
     def _current_token(self, instance_name: str) -> int:
-        push_rules_token, _ = self.store.get_push_rules_stream_token()
+        push_rules_token = self.store.get_max_push_rules_stream_id()
         return push_rules_token
 
 
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e2df638cc5..e781a3bcf4 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
         return 200, {}
 
     def notify_user(self, user_id):
-        stream_id, _ = self.store.get_push_rules_stream_token()
+        stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
 
     async def set_rule_attr(self, user_id, spec, val):
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c2289a9557..a585e54812 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
                 await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
         )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
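All four mutators in this file now share the same shape: allocate a push-rules stream id, snapshot the current event stream position, and hand both to the transaction. A condensed sketch of that shared shape (the `_mutate_push_rule` helper and its `run_txn` parameter are illustrative, not part of the patch):

    async def _mutate_push_rule(self, run_txn) -> None:
        # Illustrative only: mirrors add/delete/enable/set-actions above.
        with self._push_rules_stream_id_gen.get_next() as stream_id:
            # The event stream ordering is no longer allocated together with
            # the push-rules id; it is read off the event stream's generator.
            event_stream_ordering = self._stream_id_gen.get_current_token()

            await self.db_pool.runInteraction(
                "mutate_push_rule", run_txn, stream_id, event_stream_ordering
            )
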
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8276a755e5..0bf772d4d1 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,7 +16,7 @@
 import contextlib
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, Set
 
 from typing_extensions import Deque
 
@@ -167,72 +167,6 @@ class StreamIdGenerator(object):
         return self.get_current_token()
 
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
-        """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
-
-    def get_current_token_for_writer(self, instance_name: str) -> Tuple[int, int]:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
-        return self.get_current_token()
-
-
 class MultiWriterIdGenerator:
     """An ID generator that tracks a stream that can have multiple writers.
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 393e34b9fb..7ab46f42bf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -39,7 +39,7 @@ class EventSources(object):
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
-        push_rules_key, _ = self.store.get_push_rules_stream_token()
+        push_rules_key = self.store.get_max_push_rules_stream_id()
         to_device_key = self.store.get_to_device_stream_token()
         device_list_key = self.store.get_device_stream_token()
         groups_key = self.store.get_group_stream_token()
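`get_current_token` simply snapshots each stream's position into one `StreamToken`; with this change the push-rules component is obtained as a plain `int`, like its neighbours. A toy illustration (the `Token` fields are invented stand-ins for `StreamToken`'s, not the real attribute set):

    from typing import NamedTuple

    class Token(NamedTuple):
        # Invented, minimal stand-in for StreamToken's per-stream fields.
        push_rules_key: int
        to_device_key: int

    def current_token(store) -> Token:
        # Each field is an independent stream position, read at one moment.
        return Token(
            push_rules_key=store.get_max_push_rules_stream_id(),
            to_device_key=store.get_to_device_stream_token(),
        )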