diff options
author | Erik Johnston <erikj@element.io> | 2024-03-28 15:44:07 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-28 15:44:07 +0000 |
commit | ea6bfae0fca5303bcf2e474694d6a388ef3b6a90 (patch) | |
tree | 3e276ad9730b0af63de01925e96b5fade9d13d00 /synapse/storage/databases/main/push_rule.py | |
parent | Fixup changelog (diff) | |
download | synapse-ea6bfae0fca5303bcf2e474694d6a388ef3b6a90.tar.xz |
Add support for moving `/push_rules` off of main process (#17037)
Diffstat (limited to 'synapse/storage/databases/main/push_rule.py')
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 67 |
1 files changed, 41 insertions, 26 deletions
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 91beca6ffc..ed734f03ac 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - IdGenerator, - StreamIdGenerator, -) +from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.types import JsonDict from synapse.util import json_encoder, unwrapFirstError @@ -130,6 +126,8 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ + _push_rules_stream_id_gen: StreamIdGenerator + def __init__( self, database: DatabasePool, @@ -138,6 +136,8 @@ class PushRulesWorkerStore( ): super().__init__(database, db_conn, hs) + self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push + # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. self._push_rules_stream_id_gen = StreamIdGenerator( @@ -145,7 +145,7 @@ class PushRulesWorkerStore( hs.get_replication_notifier(), "push_rules_stream", "stream_id", - is_writer=hs.config.worker.worker_app is None, + is_writer=self._is_push_writer, ) push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( @@ -162,6 +162,9 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) + self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") + self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") + def get_max_push_rules_stream_id(self) -> int: """Get the position of the push rules stream. @@ -383,23 +386,6 @@ class PushRulesWorkerStore( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) - -class PushRuleStore(PushRulesWorkerStore): - # Because we have write access, this will be a StreamIdGenerator - # (see PushRulesWorkerStore.__init__) - _push_rules_stream_id_gen: AbstractStreamIdGenerator - - def __init__( - self, - database: DatabasePool, - db_conn: LoggingDatabaseConnection, - hs: "HomeServer", - ): - super().__init__(database, db_conn, hs) - - self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") - self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") - async def add_push_rule( self, user_id: str, @@ -410,6 +396,9 @@ class PushRuleStore(PushRulesWorkerStore): before: Optional[str] = None, after: Optional[str] = None, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) async with self._push_rules_stream_id_gen.get_next() as stream_id: @@ -455,6 +444,9 @@ class PushRuleStore(PushRulesWorkerStore): before: str, after: str, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + relative_to_rule = before or after sql = """ @@ -524,6 +516,9 @@ class PushRuleStore(PushRulesWorkerStore): conditions_json: str, actions_json: str, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + if isinstance(self.database_engine, PostgresEngine): # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first # then re-select the count/max below. @@ -575,6 +570,9 @@ class PushRuleStore(PushRulesWorkerStore): actions_json: str, update_stream: bool = True, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -653,6 +651,8 @@ class PushRuleStore(PushRulesWorkerStore): user_id: The matrix ID of the push rule owner rule_id: The rule_id of the rule to be deleted """ + if not self._is_push_writer: + raise Exception("Not a push writer") def delete_push_rule_txn( txn: LoggingTransaction, @@ -704,6 +704,9 @@ class PushRuleStore(PushRulesWorkerStore): Raises: RuleNotFoundException if the rule does not exist. """ + if not self._is_push_writer: + raise Exception("Not a push writer") + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -727,6 +730,9 @@ class PushRuleStore(PushRulesWorkerStore): enabled: bool, is_default_rule: bool, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + new_id = self._push_rules_enable_id_gen.get_next() if not is_default_rule: @@ -796,6 +802,9 @@ class PushRuleStore(PushRulesWorkerStore): Raises: RuleNotFoundException if the rule does not exist. """ + if not self._is_push_writer: + raise Exception("Not a push writer") + actions_json = json_encoder.encode(actions) def set_push_rule_actions_txn( @@ -865,6 +874,9 @@ class PushRuleStore(PushRulesWorkerStore): op: str, data: Optional[JsonDict] = None, ) -> None: + if not self._is_push_writer: + raise Exception("Not a push writer") + values = { "stream_id": stream_id, "event_stream_ordering": event_stream_ordering, @@ -882,9 +894,6 @@ class PushRuleStore(PushRulesWorkerStore): self.push_rules_stream_cache.entity_has_changed, user_id, stream_id ) - def get_max_push_rules_stream_id(self) -> int: - return self._push_rules_stream_id_gen.get_current_token() - async def copy_push_rule_from_room_to_room( self, new_room_id: str, user_id: str, rule: PushRule ) -> None: @@ -895,6 +904,9 @@ class PushRuleStore(PushRulesWorkerStore): user_id : ID of user the push rule belongs to. rule: A push rule. """ + if not self._is_push_writer: + raise Exception("Not a push writer") + # Create new rule id rule_id_scope = "/".join(rule.rule_id.split("/")[:-1]) new_rule_id = rule_id_scope + "/" + new_room_id @@ -930,6 +942,9 @@ class PushRuleStore(PushRulesWorkerStore): new_room_id: ID of the new room. user_id: ID of user to copy push rules for. """ + if not self._is_push_writer: + raise Exception("Not a push writer") + # Retrieve push rules for this user user_push_rules = await self.get_push_rules_for_user(user_id) |