diff --git a/changelog.d/17037.feature b/changelog.d/17037.feature
new file mode 100644
index 0000000000..498221e19e
--- /dev/null
+++ b/changelog.d/17037.feature
@@ -0,0 +1 @@
+Add support for moving `/pushrules` off of main process.
diff --git a/changelog.d/17038.feature b/changelog.d/17038.feature
new file mode 100644
index 0000000000..498221e19e
--- /dev/null
+++ b/changelog.d/17038.feature
@@ -0,0 +1 @@
+Add support for moving `/pushrules` off of main process.
diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc
new file mode 100644
index 0000000000..a1439752d3
--- /dev/null
+++ b/changelog.d/17044.misc
@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index 3917d9ae7e..77534a4f4f 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -310,6 +310,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {},
"worker_extra_conf": "",
},
+ "push_rules": {
+ "app": "synapse.app.generic_worker",
+ "listener_resources": ["client", "replication"],
+ "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/"],
+ "shared_extra_conf": {},
+ "worker_extra_conf": "",
+ },
}
# Templates for sections that may be inserted multiple times in config files
@@ -401,6 +408,7 @@ def add_worker_roles_to_shared_config(
"receipts",
"to_device",
"typing",
+ "push_rules",
]
# Worker-type specific sharding config. Now a single worker can fulfill multiple
diff --git a/docs/workers.md b/docs/workers.md
index d19f1a9dea..ab9c1db86b 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -532,6 +532,13 @@ the stream writer for the `presence` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
+##### The `push_rules` stream
+
+The following endpoints should be routed directly to the worker configured as
+the stream writer for the `push` stream:
+
+ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
+
#### Restrict outbound federation traffic to a specific set of workers
The
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index a533cad5ae..15507372a4 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -60,7 +60,7 @@ from synapse.logging.context import (
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
-from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
+from synapse.storage.databases.main import FilteringWorkerStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -77,10 +77,8 @@ from synapse.storage.databases.main.media_repository import (
)
from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
from synapse.storage.databases.main.profile import ProfileWorkerStore
-from synapse.storage.databases.main.pusher import (
- PusherBackgroundUpdatesStore,
- PusherWorkerStore,
-)
+from synapse.storage.databases.main.push_rule import PusherWorkerStore
+from synapse.storage.databases.main.pusher import PusherBackgroundUpdatesStore
from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
@@ -245,7 +243,6 @@ class Store(
AccountDataWorkerStore,
FilteringWorkerStore,
ProfileWorkerStore,
- PushRuleStore,
PusherWorkerStore,
PusherBackgroundUpdatesStore,
PresenceBackgroundUpdateStore,
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index e9c67807e5..7ecf349e4a 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -156,6 +156,8 @@ class WriterLocations:
can only be a single instance.
presence: The instances that write to the presence stream. Currently
can only be a single instance.
+ push_rules: The instances that write to the push stream. Currently
+ can only be a single instance.
"""
events: List[str] = attr.ib(
@@ -182,6 +184,10 @@ class WriterLocations:
default=["master"],
converter=_instance_to_list_converter,
)
+ push_rules: List[str] = attr.ib(
+ default=["master"],
+ converter=_instance_to_list_converter,
+ )
@attr.s(auto_attribs=True)
@@ -341,6 +347,7 @@ class WorkerConfig(Config):
"account_data",
"receipts",
"presence",
+ "push_rules",
):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
@@ -378,6 +385,11 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `presence` messages."
)
+ if len(self.writers.push_rules) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `push` messages."
+ )
+
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9e9f6cd062..601d37341b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -51,6 +51,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.push import ReplicationCopyPusherRestServlet
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
JsonDict,
@@ -181,6 +182,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
hs.config.server.forgotten_room_retention_period
)
+ self._is_push_writer = (
+ hs.get_instance_name() in hs.config.worker.writers.push_rules
+ )
+ self._push_writer = hs.config.worker.writers.push_rules[0]
+ self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
+
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
"""Notify the rate limiter that a room join has occurred.
@@ -1301,9 +1308,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
old_room_id, new_room_id, user_id
)
# Copy over push rules
- await self.store.copy_push_rules_from_room_to_room_for_user(
- old_room_id, new_room_id, user_id
- )
+ if self._is_push_writer:
+ await self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ else:
+ await self._copy_push_client(
+ instance_name=self._push_writer,
+ user_id=user_id,
+ old_room_id=old_room_id,
+ new_room_id=new_room_id,
+ )
except Exception:
logger.exception(
"Error copying tags and/or push rules from rooms %s to %s for user %s. "
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index 8e5641707a..de07e75b46 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -77,5 +77,46 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {}
+class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
+ """Copies push rules from an old room to new room.
+
+ Request format:
+
+ POST /_synapse/replication/copy_push_rules/:user_id/:old_room_id/:new_room_id
+
+ {}
+
+ """
+
+ NAME = "copy_push_rules"
+ PATH_ARGS = ("user_id", "old_room_id", "new_room_id")
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self._store = hs.get_datastores().main
+
+ @staticmethod
+ async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override]
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self,
+ request: Request,
+ content: JsonDict,
+ user_id: str,
+ old_room_id: str,
+ new_room_id: str,
+ ) -> Tuple[int, JsonDict]:
+
+ await self._store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+
+ return 200, {}
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemovePusherRestServlet(hs).register(http_server)
+ ReplicationCopyPusherRestServlet(hs).register(http_server)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ecc12c0b28..72a42cb6cc 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -66,6 +66,7 @@ from synapse.replication.tcp.streams import (
FederationStream,
PresenceFederationStream,
PresenceStream,
+ PushRulesStream,
ReceiptsStream,
Stream,
ToDeviceStream,
@@ -178,6 +179,12 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, PushRulesStream):
+ if hs.get_instance_name() in hs.config.worker.writers.push_rules:
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker.worker_app is not None:
continue
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 7d58611abb..af042504c9 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -59,12 +59,14 @@ class PushRuleRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker.worker_app is not None
+ self._is_push_worker = (
+ hs.get_instance_name() in hs.config.worker.writers.push_rules
+ )
self._push_rules_handler = hs.get_push_rules_handler()
self._push_rule_linearizer = Linearizer(name="push_rules")
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
- if self._is_worker:
+ if not self._is_push_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
requester = await self.auth.get_user_by_req(request)
@@ -137,7 +139,7 @@ class PushRuleRestServlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, path: str
) -> Tuple[int, JsonDict]:
- if self._is_worker:
+ if not self._is_push_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index bf779587d9..586e84f2a4 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,7 +63,7 @@ from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
from .purge_events import PurgeEventsStore
-from .push_rule import PushRuleStore
+from .push_rule import PushRulesWorkerStore
from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
@@ -130,7 +130,6 @@ class DataStore(
RejectionsStore,
FilteringWorkerStore,
PusherStore,
- PushRuleStore,
ApplicationServiceTransactionStore,
EventPushActionsStore,
ServerMetricsStore,
@@ -140,6 +139,7 @@ class DataStore(
SearchStore,
TagsStore,
AccountDataStore,
+ PushRulesWorkerStore,
StreamWorkerStore,
OpenIdStore,
ClientIpWorkerStore,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 846c3f363a..fb132ef090 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -27,6 +27,7 @@ from typing import (
Collection,
Dict,
FrozenSet,
+ Generator,
Iterable,
List,
Optional,
@@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
- #
- # This query is structured to first get all chain IDs reachable, and
- # then pull out all links from those chains. This does pull out more
- # rows than is strictly necessary, however there isn't a way of
- # structuring the recursive part of query to pull out the links without
- # also returning large quantities of redundant data (which can make it a
- # lot slower).
- sql = """
- WITH RECURSIVE links(chain_id) AS (
- SELECT
- DISTINCT origin_chain_id
- FROM event_auth_chain_links WHERE %s
- UNION
- SELECT
- target_chain_id
- FROM event_auth_chain_links
- INNER JOIN links ON (chain_id = origin_chain_id)
- )
- SELECT
- origin_chain_id, origin_sequence_number,
- target_chain_id, target_sequence_number
- FROM links
- INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
- """
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
-
- # Add all linked chains reachable from initial set of chains.
- chains_to_fetch = set(event_chains.keys())
- while chains_to_fetch:
- batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
- chains_to_fetch.difference_update(batch2)
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch2
- )
- txn.execute(sql % (clause,), args)
-
- links: Dict[int, List[Tuple[int, int, int]]] = {}
-
- for (
- origin_chain_id,
- origin_sequence_number,
- target_chain_id,
- target_sequence_number,
- ) in txn:
- links.setdefault(origin_chain_id, []).append(
- (origin_sequence_number, target_chain_id, target_sequence_number)
- )
-
+ for links in self._get_chain_links(txn, set(event_chains.keys())):
for chain_id in links:
if chain_id not in event_chains:
continue
_materialize(chain_id, event_chains[chain_id], links, chains)
- chains_to_fetch.difference_update(chains)
-
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
@@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results
+ @classmethod
+ def _get_chain_links(
+ cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+ ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
+ """Fetch all auth chain links from the given set of chains, and all
+ links from those chains, recursively.
+
+ Note: This may return links that are not reachable from the given
+ chains.
+
+ Returns a generator that produces dicts from origin chain ID to 3-tuple
+ of origin sequence number, target chain ID and target sequence number.
+ """
+
+ # This query is structured to first get all chain IDs reachable, and
+ # then pull out all links from those chains. This does pull out more
+ # rows than is strictly necessary, however there isn't a way of
+ # structuring the recursive part of query to pull out the links without
+ # also returning large quantities of redundant data (which can make it a
+ # lot slower).
+ sql = """
+ WITH RECURSIVE links(chain_id) AS (
+ SELECT
+ DISTINCT origin_chain_id
+ FROM event_auth_chain_links WHERE %s
+ UNION
+ SELECT
+ target_chain_id
+ FROM event_auth_chain_links
+ INNER JOIN links ON (chain_id = origin_chain_id)
+ )
+ SELECT
+ origin_chain_id, origin_sequence_number,
+ target_chain_id, target_sequence_number
+ FROM links
+ INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
+ """
+
+ while chains_to_fetch:
+ batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
+ chains_to_fetch.difference_update(batch2)
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "origin_chain_id", batch2
+ )
+ txn.execute(sql % (clause,), args)
+
+ links: Dict[int, List[Tuple[int, int, int]]] = {}
+
+ for (
+ origin_chain_id,
+ origin_sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ ) in txn:
+ links.setdefault(origin_chain_id, []).append(
+ (origin_sequence_number, target_chain_id, target_sequence_number)
+ )
+
+ chains_to_fetch.difference_update(links)
+
+ yield links
+
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
@@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
- #
- # This query is structured to first get all chain IDs reachable, and
- # then pull out all links from those chains. This does pull out more
- # rows than is strictly necessary, however there isn't a way of
- # structuring the recursive part of query to pull out the links without
- # also returning large quantities of redundant data (which can make it a
- # lot slower).
- sql = """
- WITH RECURSIVE links(chain_id) AS (
- SELECT
- DISTINCT origin_chain_id
- FROM event_auth_chain_links WHERE %s
- UNION
- SELECT
- target_chain_id
- FROM event_auth_chain_links
- INNER JOIN links ON (chain_id = origin_chain_id)
- )
- SELECT
- origin_chain_id, origin_sequence_number,
- target_chain_id, target_sequence_number
- FROM links
- INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
- """
-
- # (We need to take a copy of `seen_chains` as we want to mutate it in
- # the loop)
- chains_to_fetch = set(seen_chains)
- while chains_to_fetch:
- batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "origin_chain_id", batch2
- )
- txn.execute(sql % (clause,), args)
-
- links: Dict[int, List[Tuple[int, int, int]]] = {}
-
- for (
- origin_chain_id,
- origin_sequence_number,
- target_chain_id,
- target_sequence_number,
- ) in txn:
- links.setdefault(origin_chain_id, []).append(
- (origin_sequence_number, target_chain_id, target_sequence_number)
- )
+ # (We need to take a copy of `seen_chains` as the function mutates it)
+ for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
@@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_materialize(chain_id, chains[chain_id], links, chains)
- chains_to_fetch.difference_update(chains)
seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 91beca6ffc..660c834518 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- IdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
+ _push_rules_stream_id_gen: StreamIdGenerator
+
def __init__(
self,
database: DatabasePool,
@@ -138,6 +136,10 @@ class PushRulesWorkerStore(
):
super().__init__(database, db_conn, hs)
+ self._is_push_writer = (
+ hs.get_instance_name() in hs.config.worker.writers.push_rules
+ )
+
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +147,7 @@ class PushRulesWorkerStore(
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
- is_writer=hs.config.worker.worker_app is None,
+ is_writer=self._is_push_writer,
)
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +164,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
@@ -383,23 +388,6 @@ class PushRulesWorkerStore(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
-
-class PushRuleStore(PushRulesWorkerStore):
- # Because we have write access, this will be a StreamIdGenerator
- # (see PushRulesWorkerStore.__init__)
- _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
async def add_push_rule(
self,
user_id: str,
@@ -410,6 +398,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +446,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: str,
after: str,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
relative_to_rule = before or after
sql = """
@@ -524,6 +518,9 @@ class PushRuleStore(PushRulesWorkerStore):
conditions_json: str,
actions_json: str,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
if isinstance(self.database_engine, PostgresEngine):
# Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
# then re-select the count/max below.
@@ -575,6 +572,9 @@ class PushRuleStore(PushRulesWorkerStore):
actions_json: str,
update_stream: bool = True,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -653,6 +653,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: The matrix ID of the push rule owner
rule_id: The rule_id of the rule to be deleted
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
def delete_push_rule_txn(
txn: LoggingTransaction,
@@ -704,6 +706,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
RuleNotFoundException if the rule does not exist.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -727,6 +732,9 @@ class PushRuleStore(PushRulesWorkerStore):
enabled: bool,
is_default_rule: bool,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -796,6 +804,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises:
RuleNotFoundException if the rule does not exist.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(
@@ -865,6 +876,9 @@ class PushRuleStore(PushRulesWorkerStore):
op: str,
data: Optional[JsonDict] = None,
) -> None:
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -882,9 +896,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_max_push_rules_stream_id(self) -> int:
- return self._push_rules_stream_id_gen.get_current_token()
-
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
@@ -895,6 +906,9 @@ class PushRuleStore(PushRulesWorkerStore):
user_id : ID of user the push rule belongs to.
rule: A push rule.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
# Create new rule id
rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +944,9 @@ class PushRuleStore(PushRulesWorkerStore):
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
"""
+ if not self._is_push_writer:
+ raise Exception("Not a push writer")
+
# Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id)
|