summary refs log tree commit diff
path: root/synapse/storage/databases/main/push_rule.py
diff options
context:
space:
mode:
authorErik Johnston <erikj@element.io>2024-03-28 15:44:07 +0000
committerGitHub <noreply@github.com>2024-03-28 15:44:07 +0000
commitea6bfae0fca5303bcf2e474694d6a388ef3b6a90 (patch)
tree3e276ad9730b0af63de01925e96b5fade9d13d00 /synapse/storage/databases/main/push_rule.py
parentFixup changelog (diff)
downloadsynapse-ea6bfae0fca5303bcf2e474694d6a388ef3b6a90.tar.xz
Add support for moving `/push_rules` off of main process (#17037)
Diffstat (limited to 'synapse/storage/databases/main/push_rule.py')
-rw-r--r--synapse/storage/databases/main/push_rule.py67
1 files changed, 41 insertions, 26 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 91beca6ffc..ed734f03ac 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
+    _push_rules_stream_id_gen: StreamIdGenerator
+
     def __init__(
         self,
         database: DatabasePool,
@@ -138,6 +136,8 @@ class PushRulesWorkerStore(
     ):
         super().__init__(database, db_conn, hs)
 
+        self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
         self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +145,7 @@ class PushRulesWorkerStore(
             hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
-            is_writer=hs.config.worker.worker_app is None,
+            is_writer=self._is_push_writer,
         )
 
         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +162,9 @@ class PushRulesWorkerStore(
             prefilled_cache=push_rules_prefill,
         )
 
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
 
@@ -383,23 +386,6 @@ class PushRulesWorkerStore(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
 
-
-class PushRuleStore(PushRulesWorkerStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see PushRulesWorkerStore.__init__)
-    _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
     async def add_push_rule(
         self,
         user_id: str,
@@ -410,6 +396,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +444,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: str,
         after: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         relative_to_rule = before or after
 
         sql = """
@@ -524,6 +516,9 @@ class PushRuleStore(PushRulesWorkerStore):
         conditions_json: str,
         actions_json: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         if isinstance(self.database_engine, PostgresEngine):
             # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
             # then re-select the count/max below.
@@ -575,6 +570,9 @@ class PushRuleStore(PushRulesWorkerStore):
         actions_json: str,
         update_stream: bool = True,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -653,6 +651,8 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id: The matrix ID of the push rule owner
             rule_id: The rule_id of the rule to be deleted
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
 
         def delete_push_rule_txn(
             txn: LoggingTransaction,
@@ -704,6 +704,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
@@ -727,6 +730,9 @@ class PushRuleStore(PushRulesWorkerStore):
         enabled: bool,
         is_default_rule: bool,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         new_id = self._push_rules_enable_id_gen.get_next()
 
         if not is_default_rule:
@@ -796,6 +802,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         actions_json = json_encoder.encode(actions)
 
         def set_push_rule_actions_txn(
@@ -865,6 +874,9 @@ class PushRuleStore(PushRulesWorkerStore):
         op: str,
         data: Optional[JsonDict] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -882,9 +894,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_max_push_rules_stream_id(self) -> int:
-        return self._push_rules_stream_id_gen.get_current_token()
-
     async def copy_push_rule_from_room_to_room(
         self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
@@ -895,6 +904,9 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id : ID of user the push rule belongs to.
             rule: A push rule.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Create new rule id
         rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +942,9 @@ class PushRuleStore(PushRulesWorkerStore):
             new_room_id: ID of the new room.
             user_id: ID of user to copy push rules for.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Retrieve push rules for this user
         user_push_rules = await self.get_push_rules_for_user(user_id)