diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 91beca6ffc..ed734f03ac 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- IdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
+ _push_rules_stream_id_gen: StreamIdGenerator
+
def __init__(
self,
database: DatabasePool,
@@ -138,6 +136,8 @@ class PushRulesWorkerStore(
):
super().__init__(database, db_conn, hs)
+ self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +145,7 @@ class PushRulesWorkerStore(
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
- is_writer=hs.config.worker.worker_app is None,
+ is_writer=self._is_push_writer,
)
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +162,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
@@ -383,23 +386,6 @@ class PushRulesWorkerStore(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
-
-class PushRuleStore(PushRulesWorkerStore):
- # Because we have write access, this will be a StreamIdGenerator
- # (see PushRulesWorkerStore.__init__)
- _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
async def add_push_rule(
self,
user_id: str,
@@ -410,6 +396,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +444,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: str,
after: str,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
relative_to_rule = before or after
sql = """
@@ -524,6 +516,9 @@ class PushRuleStore(PushRulesWorkerStore):
conditions_json: str,
actions_json: str,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
if isinstance(self.database_engine, PostgresEngine):
# Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
# then re-select the count/max below.
@@ -575,6 +570,9 @@ class PushRuleStore(PushRulesWorkerStore):
actions_json: str,
update_stream: bool = True,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -653,6 +651,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: The matrix ID of the push rule owner
rule_id: The rule_id of the rule to be deleted
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
def delete_push_rule_txn(
txn: LoggingTransaction,
@@ -704,6 +704,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
RuleNotFoundException if the rule does not exist.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -727,6 +730,9 @@ class PushRuleStore(PushRulesWorkerStore):
enabled: bool,
is_default_rule: bool,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -796,6 +802,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
RuleNotFoundException if the rule does not exist.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(
@@ -865,6 +874,9 @@ class PushRuleStore(PushRulesWorkerStore):
op: str,
data: Optional[JsonDict] = None,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -882,9 +894,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_max_push_rules_stream_id(self) -> int:
- return self._push_rules_stream_id_gen.get_current_token()
-
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
@@ -895,6 +904,9 @@ class PushRuleStore(PushRulesWorkerStore):
user_id : ID of user the push rule belongs to.
rule: A push rule.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
# Create new rule id
rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +942,9 @@ class PushRuleStore(PushRulesWorkerStore):
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
# Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id)
|