-rw-r--r--  changelog.d/17037.feature                              1
-rw-r--r--  changelog.d/17038.feature                              1
-rw-r--r--  changelog.d/17044.misc                                 1
-rwxr-xr-x  docker/configure_workers_and_start.py                  8
-rw-r--r--  docs/workers.md                                        7
-rwxr-xr-x  synapse/_scripts/synapse_port_db.py                    9
-rw-r--r--  synapse/config/workers.py                             12
-rw-r--r--  synapse/handlers/room_member.py                       21
-rw-r--r--  synapse/replication/http/push.py                      41
-rw-r--r--  synapse/replication/tcp/handler.py                     7
-rw-r--r--  synapse/rest/client/push_rule.py                       8
-rw-r--r--  synapse/storage/databases/main/__init__.py             4
-rw-r--r--  synapse/storage/databases/main/event_federation.py   162
-rw-r--r--  synapse/storage/databases/main/push_rule.py            69
14 files changed, 215 insertions, 136 deletions
diff --git a/changelog.d/17037.feature b/changelog.d/17037.feature
new file mode 100644
index 0000000000..498221e19e
--- /dev/null
+++ b/changelog.d/17037.feature
@@ -0,0 +1 @@
+Add support for moving `/pushrules` off of main process.
diff --git a/changelog.d/17038.feature b/changelog.d/17038.feature
new file mode 100644
index 0000000000..498221e19e
--- /dev/null
+++ b/changelog.d/17038.feature
@@ -0,0 +1 @@
+Add support for moving `/pushrules` off of main process.
diff --git a/changelog.d/17044.misc b/changelog.d/17044.misc
new file mode 100644
index 0000000000..a1439752d3
--- /dev/null
+++ b/changelog.d/17044.misc
@@ -0,0 +1 @@
+Refactor auth chain fetching to reduce duplication.
diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py
index 3917d9ae7e..77534a4f4f 100755
--- a/docker/configure_workers_and_start.py
+++ b/docker/configure_workers_and_start.py
@@ -310,6 +310,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
         "shared_extra_conf": {},
         "worker_extra_conf": "",
     },
+    "push_rules": {
+        "app": "synapse.app.generic_worker",
+        "listener_resources": ["client", "replication"],
+        "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/"],
+        "shared_extra_conf": {},
+        "worker_extra_conf": "",
+    },
 }

 # Templates for sections that may be inserted multiple times in config files
@@ -401,6 +408,7 @@ def add_worker_roles_to_shared_config(
         "receipts",
         "to_device",
         "typing",
+        "push_rules",
     ]

     # Worker-type specific sharding config. Now a single worker can fulfill multiple
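Deployment note: judging from how the other `WORKERS_CONFIG` entries are consumed by this script, the new entry should make `push_rules` an accepted value in the Docker worker image's `SYNAPSE_WORKER_TYPES` environment variable (for example `SYNAPSE_WORKER_TYPES="push_rules"`), producing a generic worker that serves the `/pushrules/` client endpoints and a replication listener. This is an inference from the surrounding code, not something the patch states explicitly.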
diff --git a/docs/workers.md b/docs/workers.md
index d19f1a9dea..ab9c1db86b 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -532,6 +532,13 @@ the stream writer for the `presence` stream:

     ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/

+##### The `push_rules` stream
+
+The following endpoints should be routed directly to the worker configured as
+the stream writer for the `push` stream:
+
+    ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
+
 #### Restrict outbound federation traffic to a specific set of workers

 The
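In practice this pairs with the shared-configuration `stream_writers` setting: name a single worker under `stream_writers.push_rules` and point the reverse proxy's `^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/` routes at that worker. The exact YAML is not shown in this hunk; the layout follows the existing stream-writer sections of workers.md.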
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index a533cad5ae..15507372a4 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -60,7 +60,7 @@ from synapse.logging.context import (
 )
 from synapse.notifier import ReplicationNotifier
 from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
-from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
+from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -77,10 +77,8 @@ from synapse.storage.databases.main.media_repository import (
 )
 from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
 from synapse.storage.databases.main.profile import ProfileWorkerStore
-from synapse.storage.databases.main.pusher import (
-    PusherBackgroundUpdatesStore,
-    PusherWorkerStore,
-)
+from synapse.storage.databases.main.push_rule import PusherWorkerStore
+from synapse.storage.databases.main.pusher import PusherBackgroundUpdatesStore
 from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
 from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
@@ -245,7 +243,6 @@ class Store(
     AccountDataWorkerStore,
     FilteringWorkerStore,
     ProfileWorkerStore,
-    PushRuleStore,
     PusherWorkerStore,
     PusherBackgroundUpdatesStore,
     PresenceBackgroundUpdateStore,
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index e9c67807e5..7ecf349e4a 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -156,6 +156,8 @@ class WriterLocations:
             can only be a single instance.
         presence: The instances that write to the presence stream. Currently
             can only be a single instance.
+        push_rules: The instances that write to the push stream. Currently
+            can only be a single instance.
     """

     events: List[str] = attr.ib(
@@ -182,6 +184,10 @@
         default=["master"],
         converter=_instance_to_list_converter,
     )
+    push_rules: List[str] = attr.ib(
+        default=["master"],
+        converter=_instance_to_list_converter,
+    )


 @attr.s(auto_attribs=True)
@@ -341,6 +347,7 @@ class WorkerConfig(Config):
             "account_data",
             "receipts",
             "presence",
+            "push_rules",
         ):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
@@ -378,6 +385,11 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `presence` messages."
             )

+        if len(self.writers.push_rules) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `push` messages."
+            )
+
         self.events_shard_config = RoutableShardedWorkerHandlingConfig(
             self.writers.events
         )
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9e9f6cd062..601d37341b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -51,6 +51,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.push import ReplicationCopyPusherRestServlet
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
     JsonDict,
@@ -181,6 +182,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             hs.config.server.forgotten_room_retention_period
         )

+        self._is_push_writer = (
+            hs.get_instance_name() in hs.config.worker.writers.push_rules
+        )
+        self._push_writer = hs.config.worker.writers.push_rules[0]
+        self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
+
     def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
         """Notify the rate limiter that a room join has occurred.

@@ -1301,9 +1308,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     old_room_id, new_room_id, user_id
                 )
                 # Copy over push rules
-                await self.store.copy_push_rules_from_room_to_room_for_user(
-                    old_room_id, new_room_id, user_id
-                )
+                if self._is_push_writer:
+                    await self.store.copy_push_rules_from_room_to_room_for_user(
+                        old_room_id, new_room_id, user_id
+                    )
+                else:
+                    await self._copy_push_client(
+                        instance_name=self._push_writer,
+                        user_id=user_id,
+                        old_room_id=old_room_id,
+                        new_room_id=new_room_id,
+                    )
             except Exception:
                 logger.exception(
                     "Error copying tags and/or push rules from rooms %s to %s for user %s. "
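The handler now decides at construction time whether this instance is the configured `push_rules` writer. The writer keeps the direct database call; every other worker forwards the copy over HTTP replication using the client built from `ReplicationCopyPusherRestServlet`, which is defined in the next file (a usage sketch follows that servlet).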
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index 8e5641707a..de07e75b46 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -77,5 +77,46 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
         return 200, {}


+class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
+    """Copies push rules from an old room to new room.
+
+    Request format:
+
+        POST /_synapse/replication/copy_push_rules/:user_id/:old_room_id/:new_room_id
+
+        {}
+
+    """
+
+    NAME = "copy_push_rules"
+    PATH_ARGS = ("user_id", "old_room_id", "new_room_id")
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._store = hs.get_datastores().main
+
+    @staticmethod
+    async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict:  # type: ignore[override]
+        return {}
+
+    async def _handle_request(  # type: ignore[override]
+        self,
+        request: Request,
+        content: JsonDict,
+        user_id: str,
+        old_room_id: str,
+        new_room_id: str,
+    ) -> Tuple[int, JsonDict]:
+
+        await self._store.copy_push_rules_from_room_to_room_for_user(
+            old_room_id, new_room_id, user_id
+        )
+
+        return 200, {}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationRemovePusherRestServlet(hs).register(http_server)
+    ReplicationCopyPusherRestServlet(hs).register(http_server)
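As a rough usage sketch (not additional patch code; it assumes only names visible in this diff plus a `HomeServer` instance `hs` and a hypothetical wrapper function), the write-locally-or-forward pattern adopted by `RoomMemberHandler` looks like this:

    from synapse.replication.http.push import ReplicationCopyPusherRestServlet

    # Build a callable client for the replication endpoint defined above.
    copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)

    async def copy_rules_on_upgrade(old_room_id: str, new_room_id: str, user_id: str) -> None:
        if hs.get_instance_name() in hs.config.worker.writers.push_rules:
            # This instance is the push_rules writer: write to the database directly.
            await hs.get_datastores().main.copy_push_rules_from_room_to_room_for_user(
                old_room_id, new_room_id, user_id
            )
        else:
            # Any other worker forwards the write to the configured writer. On the
            # wire this becomes
            #   POST /_synapse/replication/copy_push_rules/{user_id}/{old_room_id}/{new_room_id}
            # with an empty JSON body, matching the servlet's docstring.
            await copy_push_client(
                instance_name=hs.config.worker.writers.push_rules[0],
                user_id=user_id,
                old_room_id=old_room_id,
                new_room_id=new_room_id,
            )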
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ecc12c0b28..72a42cb6cc 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -66,6 +66,7 @@ from synapse.replication.tcp.streams import (
     FederationStream,
     PresenceFederationStream,
     PresenceStream,
+    PushRulesStream,
     ReceiptsStream,
     Stream,
     ToDeviceStream,
@@ -178,6 +179,12 @@ class ReplicationCommandHandler:

                 continue

+            if isinstance(stream, PushRulesStream):
+                if hs.get_instance_name() in hs.config.worker.writers.push_rules:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker.worker_app is not None:
                 continue
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 7d58611abb..af042504c9 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -59,12 +59,14 @@ class PushRuleRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker.worker_app is not None
+        self._is_push_worker = (
+            hs.get_instance_name() in hs.config.worker.writers.push_rules
+        )
         self._push_rules_handler = hs.get_push_rules_handler()
         self._push_rule_linearizer = Linearizer(name="push_rules")

     async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle PUT /push_rules on worker")

         requester = await self.auth.get_user_by_req(request)
@@ -137,7 +139,7 @@ class PushRuleRestServlet(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, path: str
     ) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle DELETE /push_rules on worker")

         requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index bf779587d9..586e84f2a4 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,7 +63,7 @@ from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
 from .purge_events import PurgeEventsStore
-from .push_rule import PushRuleStore
+from .push_rule import PushRulesWorkerStore
 from .pusher import PusherStore
 from .receipts import ReceiptsStore
 from .registration import RegistrationStore
@@ -130,7 +130,6 @@ class DataStore(
     RejectionsStore,
     FilteringWorkerStore,
     PusherStore,
-    PushRuleStore,
     ApplicationServiceTransactionStore,
     EventPushActionsStore,
     ServerMetricsStore,
@@ -140,6 +139,7 @@ class DataStore(
     SearchStore,
     TagsStore,
     AccountDataStore,
+    PushRulesWorkerStore,
     StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 846c3f363a..fb132ef090 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -27,6 +27,7 @@ from typing import (
     Collection,
     Dict,
     FrozenSet,
+    Generator,
     Iterable,
     List,
     Optional,
@@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas

         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
-        #
-        # This query is structured to first get all chain IDs reachable, and
-        # then pull out all links from those chains. This does pull out more
-        # rows than is strictly necessary, however there isn't a way of
-        # structuring the recursive part of query to pull out the links without
-        # also returning large quantities of redundant data (which can make it a
-        # lot slower).
-        sql = """
-            WITH RECURSIVE links(chain_id) AS (
-                SELECT
-                    DISTINCT origin_chain_id
-                FROM event_auth_chain_links WHERE %s
-                UNION
-                SELECT
-                    target_chain_id
-                FROM event_auth_chain_links
-                INNER JOIN links ON (chain_id = origin_chain_id)
-            )
-            SELECT
-                origin_chain_id, origin_sequence_number,
-                target_chain_id, target_sequence_number
-            FROM links
-            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
-        """

         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-
-        # Add all linked chains reachable from initial set of chains.
-        chains_to_fetch = set(event_chains.keys())
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch2
-            )
-            txn.execute(sql % (clause,), args)
-
-            links: Dict[int, List[Tuple[int, int, int]]] = {}
-
-            for (
-                origin_chain_id,
-                origin_sequence_number,
-                target_chain_id,
-                target_sequence_number,
-            ) in txn:
-                links.setdefault(origin_chain_id, []).append(
-                    (origin_sequence_number, target_chain_id, target_sequence_number)
-                )
-
+        for links in self._get_chain_links(txn, set(event_chains.keys())):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue

                 _materialize(chain_id, event_chains[chain_id], links, chains)

-            chains_to_fetch.difference_update(chains)
-
         # Add the initial set of chains, excluding the sequence corresponding to
         # initial event.
         for chain_id, seq_no in event_chains.items():
@@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas

         return results

+    @classmethod
+    def _get_chain_links(
+        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+    ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
+        """Fetch all auth chain links from the given set of chains, and all
+        links from those chains, recursively.
+
+        Note: This may return links that are not reachable from the given
+        chains.
+
+        Returns a generator that produces dicts from origin chain ID to 3-tuple
+        of origin sequence number, target chain ID and target sequence number.
+        """
+
+        # This query is structured to first get all chain IDs reachable, and
+        # then pull out all links from those chains. This does pull out more
+        # rows than is strictly necessary, however there isn't a way of
+        # structuring the recursive part of query to pull out the links without
+        # also returning large quantities of redundant data (which can make it a
+        # lot slower).
+        sql = """
+            WITH RECURSIVE links(chain_id) AS (
+                SELECT
+                    DISTINCT origin_chain_id
+                FROM event_auth_chain_links WHERE %s
+                UNION
+                SELECT
+                    target_chain_id
+                FROM event_auth_chain_links
+                INNER JOIN links ON (chain_id = origin_chain_id)
+            )
+            SELECT
+                origin_chain_id, origin_sequence_number,
+                target_chain_id, target_sequence_number
+            FROM links
+            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
+        """
+
+        while chains_to_fetch:
+            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
+            chains_to_fetch.difference_update(batch2)
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "origin_chain_id", batch2
+            )
+            txn.execute(sql % (clause,), args)
+
+            links: Dict[int, List[Tuple[int, int, int]]] = {}
+
+            for (
+                origin_chain_id,
+                origin_sequence_number,
+                target_chain_id,
+                target_sequence_number,
+            ) in txn:
+                links.setdefault(origin_chain_id, []).append(
+                    (origin_sequence_number, target_chain_id, target_sequence_number)
+                )
+
+            chains_to_fetch.difference_update(links)
+
+            yield links
+
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
     ) -> Set[str]:
@@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas

         # Now we look up all links for the chains we have, adding chains that
         # are reachable from any event.
-        #
-        # This query is structured to first get all chain IDs reachable, and
-        # then pull out all links from those chains. This does pull out more
-        # rows than is strictly necessary, however there isn't a way of
-        # structuring the recursive part of query to pull out the links without
-        # also returning large quantities of redundant data (which can make it a
-        # lot slower).
-        sql = """
-            WITH RECURSIVE links(chain_id) AS (
-                SELECT
-                    DISTINCT origin_chain_id
-                FROM event_auth_chain_links WHERE %s
-                UNION
-                SELECT
-                    target_chain_id
-                FROM event_auth_chain_links
-                INNER JOIN links ON (chain_id = origin_chain_id)
-            )
-            SELECT
-                origin_chain_id, origin_sequence_number,
-                target_chain_id, target_sequence_number
-            FROM links
-            INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
-        """
-
-        # (We need to take a copy of `seen_chains` as we want to mutate it in
-        # the loop)
-        chains_to_fetch = set(seen_chains)
-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "origin_chain_id", batch2
-            )
-            txn.execute(sql % (clause,), args)
-
-            links: Dict[int, List[Tuple[int, int, int]]] = {}
-
-            for (
-                origin_chain_id,
-                origin_sequence_number,
-                target_chain_id,
-                target_sequence_number,
-            ) in txn:
-                links.setdefault(origin_chain_id, []).append(
-                    (origin_sequence_number, target_chain_id, target_sequence_number)
-                )
+        # (We need to take a copy of `seen_chains` as the function mutates it)
+        for links in self._get_chain_links(txn, set(seen_chains)):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains:
@@ -618,7 +589,6 @@

                     _materialize(chain_id, chains[chain_id], links, chains)

-            chains_to_fetch.difference_update(chains)
             seen_chains.update(chains)

         # Now for each chain we figure out the maximum sequence number reachable
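To make the recursive CTE concrete, here is a small self-contained illustration with hypothetical data in an in-memory SQLite database (Synapse runs the same query through its own database layer on both SQLite and Postgres). Starting from chain 1, the CTE follows origin-to-target links to collect every reachable chain ID, then returns all links that originate from any of those chains:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE event_auth_chain_links (
            origin_chain_id INTEGER,
            origin_sequence_number INTEGER,
            target_chain_id INTEGER,
            target_sequence_number INTEGER
        )
        """
    )
    # Toy data: chain 1 links to chain 2, chain 2 links to chain 3,
    # and chain 4 links to chain 3 but is unreachable from chain 1.
    conn.executemany(
        "INSERT INTO event_auth_chain_links VALUES (?, ?, ?, ?)",
        [(1, 5, 2, 3), (2, 3, 3, 7), (4, 1, 3, 2)],
    )

    # The query from _get_chain_links, with the %s placeholder filled in by a
    # hand-written IN clause instead of make_in_list_sql_clause.
    sql = """
        WITH RECURSIVE links(chain_id) AS (
            SELECT DISTINCT origin_chain_id
            FROM event_auth_chain_links WHERE origin_chain_id IN (1)
            UNION
            SELECT target_chain_id
            FROM event_auth_chain_links
            INNER JOIN links ON (chain_id = origin_chain_id)
        )
        SELECT
            origin_chain_id, origin_sequence_number,
            target_chain_id, target_sequence_number
        FROM links
        INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
    """

    # Chains reachable from chain 1 are {1, 2, 3}, so the links (1, 5, 2, 3) and
    # (2, 3, 3, 7) are returned; the unrelated link from chain 4 is not.
    for row in conn.execute(sql):
        print(row)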
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 91beca6ffc..660c834518 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """

+    _push_rules_stream_id_gen: StreamIdGenerator
+
     def __init__(
         self,
         database: DatabasePool,
@@ -138,6 +136,10 @@ class PushRulesWorkerStore(
     ):
         super().__init__(database, db_conn, hs)

+        self._is_push_writer = (
+            hs.get_instance_name() in hs.config.worker.writers.push_rules
+        )
+
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
         self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +147,7 @@
             hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
-            is_writer=hs.config.worker.worker_app is None,
+            is_writer=self._is_push_writer,
         )

         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +164,9 @@
             prefilled_cache=push_rules_prefill,
         )

+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.

@@ -383,23 +388,6 @@ class PushRulesWorkerStore(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )

-
-class PushRuleStore(PushRulesWorkerStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see PushRulesWorkerStore.__init__)
-    _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
     async def add_push_rule(
         self,
         user_id: str,
@@ -410,6 +398,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +446,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: str,
         after: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         relative_to_rule = before or after

         sql = """
@@ -524,6 +518,9 @@ class PushRuleStore(PushRulesWorkerStore):
         conditions_json: str,
         actions_json: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         if isinstance(self.database_engine, PostgresEngine):
             # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
             # then re-select the count/max below.
@@ -575,6 +572,9 @@ class PushRuleStore(PushRulesWorkerStore):
         actions_json: str,
         update_stream: bool = True,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -653,6 +653,8 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id: The matrix ID of the push rule owner
             rule_id: The rule_id of the rule to be deleted
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")

         def delete_push_rule_txn(
             txn: LoggingTransaction,
@@ -704,6 +706,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
@@ -727,6 +732,9 @@ class PushRuleStore(PushRulesWorkerStore):
         enabled: bool,
         is_default_rule: bool,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         new_id = self._push_rules_enable_id_gen.get_next()

         if not is_default_rule:
@@ -796,6 +804,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         actions_json = json_encoder.encode(actions)

         def set_push_rule_actions_txn(
@@ -865,6 +876,9 @@ class PushRuleStore(PushRulesWorkerStore):
         op: str,
         data: Optional[JsonDict] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -882,9 +896,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )

-    def get_max_push_rules_stream_id(self) -> int:
-        return self._push_rules_stream_id_gen.get_current_token()
-
     async def copy_push_rule_from_room_to_room(
         self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
@@ -895,6 +906,9 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id : ID of user the push rule belongs to.
             rule: A push rule.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Create new rule id
         rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +944,9 @@ class PushRuleStore(PushRulesWorkerStore):
             new_room_id: ID of the new room.
             user_id: ID of user to copy push rules for.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Retrieve push rules for this user
         user_push_rules = await self.get_push_rules_for_user(user_id)
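Taken together, the storage changes fold `PushRuleStore` into `PushRulesWorkerStore`: the ID generators move into the worker store's constructor, the stream ID generator's `is_writer` flag is now driven by whether this instance is named in `writers.push_rules`, and every mutating method raises "Not a push writer" when called anywhere else. Push rule writes therefore only ever happen on the single configured writer, whether that is the main process (the default) or a dedicated worker.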