summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py9
-rw-r--r--synapse/config/workers.py12
-rw-r--r--synapse/handlers/room_member.py19
-rw-r--r--synapse/replication/http/push.py41
-rw-r--r--synapse/replication/tcp/handler.py7
-rw-r--r--synapse/rest/client/push_rule.py6
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/push_rule.py67
8 files changed, 125 insertions, 40 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index a533cad5ae..15507372a4 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -60,7 +60,7 @@ from synapse.logging.context import (
 )
 from synapse.notifier import ReplicationNotifier
 from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
-from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
+from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -77,10 +77,8 @@ from synapse.storage.databases.main.media_repository import (
 )
 from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
 from synapse.storage.databases.main.profile import ProfileWorkerStore
-from synapse.storage.databases.main.pusher import (
-    PusherBackgroundUpdatesStore,
-    PusherWorkerStore,
-)
+from synapse.storage.databases.main.push_rule import PusherWorkerStore
+from synapse.storage.databases.main.pusher import PusherBackgroundUpdatesStore
 from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
 from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
@@ -245,7 +243,6 @@ class Store(
     AccountDataWorkerStore,
     FilteringWorkerStore,
     ProfileWorkerStore,
-    PushRuleStore,
     PusherWorkerStore,
     PusherBackgroundUpdatesStore,
     PresenceBackgroundUpdateStore,
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index e9c67807e5..9f81a73d6f 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -156,6 +156,8 @@ class WriterLocations:
             can only be a single instance.
         presence: The instances that write to the presence stream. Currently
             can only be a single instance.
+        push: The instances that write to the push stream. Currently
+            can only be a single instance.
     """
 
     events: List[str] = attr.ib(
@@ -182,6 +184,10 @@ class WriterLocations:
         default=["master"],
         converter=_instance_to_list_converter,
     )
+    push: List[str] = attr.ib(
+        default=["master"],
+        converter=_instance_to_list_converter,
+    )
 
 
 @attr.s(auto_attribs=True)
@@ -341,6 +347,7 @@ class WorkerConfig(Config):
             "account_data",
             "receipts",
             "presence",
+            "push",
         ):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
@@ -378,6 +385,11 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `presence` messages."
             )
 
+        if len(self.writers.push) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `push` messages."
+            )
+
         self.events_shard_config = RoutableShardedWorkerHandlingConfig(
             self.writers.events
         )
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9e9f6cd062..ee2e807afc 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -51,6 +51,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http.push import ReplicationCopyPusherRestServlet
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
     JsonDict,
@@ -181,6 +182,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             hs.config.server.forgotten_room_retention_period
         )
 
+        self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+        self._push_writer = hs.config.worker.writers.push[0]
+        self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
+
     def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
         """Notify the rate limiter that a room join has occurred.
 
@@ -1301,9 +1306,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     old_room_id, new_room_id, user_id
                 )
                 # Copy over push rules
-                await self.store.copy_push_rules_from_room_to_room_for_user(
-                    old_room_id, new_room_id, user_id
-                )
+                if self._is_push_writer:
+                    await self.store.copy_push_rules_from_room_to_room_for_user(
+                        old_room_id, new_room_id, user_id
+                    )
+                else:
+                    await self._copy_push_client(
+                        instance_name=self._push_writer,
+                        user_id=user_id,
+                        old_room_id=old_room_id,
+                        new_room_id=new_room_id,
+                    )
             except Exception:
                 logger.exception(
                     "Error copying tags and/or push rules from rooms %s to %s for user %s. "
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
index 8e5641707a..de07e75b46 100644
--- a/synapse/replication/http/push.py
+++ b/synapse/replication/http/push.py
@@ -77,5 +77,46 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
         return 200, {}
 
 
+class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
+    """Copies push rules from an old room to new room.
+
+    Request format:
+
+        POST /_synapse/replication/copy_push_rules/:user_id/:old_room_id/:new_room_id
+
+        {}
+
+    """
+
+    NAME = "copy_push_rules"
+    PATH_ARGS = ("user_id", "old_room_id", "new_room_id")
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._store = hs.get_datastores().main
+
+    @staticmethod
+    async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict:  # type: ignore[override]
+        return {}
+
+    async def _handle_request(  # type: ignore[override]
+        self,
+        request: Request,
+        content: JsonDict,
+        user_id: str,
+        old_room_id: str,
+        new_room_id: str,
+    ) -> Tuple[int, JsonDict]:
+
+        await self._store.copy_push_rules_from_room_to_room_for_user(
+            old_room_id, new_room_id, user_id
+        )
+
+        return 200, {}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationRemovePusherRestServlet(hs).register(http_server)
+    ReplicationCopyPusherRestServlet(hs).register(http_server)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index ecc12c0b28..4342d6ce70 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -66,6 +66,7 @@ from synapse.replication.tcp.streams import (
     FederationStream,
     PresenceFederationStream,
     PresenceStream,
+    PushRulesStream,
     ReceiptsStream,
     Stream,
     ToDeviceStream,
@@ -178,6 +179,12 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, PushRulesStream):
+                if hs.get_instance_name() in hs.config.worker.writers.push:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker.worker_app is not None:
                 continue
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 7d58611abb..9c9bfc9794 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -59,12 +59,12 @@ class PushRuleRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker.worker_app is not None
+        self._is_push_worker = hs.get_instance_name() in hs.config.worker.writers.push
         self._push_rules_handler = hs.get_push_rules_handler()
         self._push_rule_linearizer = Linearizer(name="push_rules")
 
     async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle PUT /push_rules on worker")
 
         requester = await self.auth.get_user_by_req(request)
@@ -137,7 +137,7 @@ class PushRuleRestServlet(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, path: str
     ) -> Tuple[int, JsonDict]:
-        if self._is_worker:
+        if not self._is_push_worker:
             raise Exception("Cannot handle DELETE /push_rules on worker")
 
         requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index bf779587d9..586e84f2a4 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -63,7 +63,7 @@ from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
 from .purge_events import PurgeEventsStore
-from .push_rule import PushRuleStore
+from .push_rule import PushRulesWorkerStore
 from .pusher import PusherStore
 from .receipts import ReceiptsStore
 from .registration import RegistrationStore
@@ -130,7 +130,6 @@ class DataStore(
     RejectionsStore,
     FilteringWorkerStore,
     PusherStore,
-    PushRuleStore,
     ApplicationServiceTransactionStore,
     EventPushActionsStore,
     ServerMetricsStore,
@@ -140,6 +139,7 @@ class DataStore(
     SearchStore,
     TagsStore,
     AccountDataStore,
+    PushRulesWorkerStore,
     StreamWorkerStore,
     OpenIdStore,
     ClientIpWorkerStore,
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 91beca6ffc..ed734f03ac 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    IdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
+    _push_rules_stream_id_gen: StreamIdGenerator
+
     def __init__(
         self,
         database: DatabasePool,
@@ -138,6 +136,8 @@ class PushRulesWorkerStore(
     ):
         super().__init__(database, db_conn, hs)
 
+        self._is_push_writer = hs.get_instance_name() in hs.config.worker.writers.push
+
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
         self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +145,7 @@ class PushRulesWorkerStore(
             hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
-            is_writer=hs.config.worker.worker_app is None,
+            is_writer=self._is_push_writer,
         )
 
         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +162,9 @@ class PushRulesWorkerStore(
             prefilled_cache=push_rules_prefill,
         )
 
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
 
@@ -383,23 +386,6 @@ class PushRulesWorkerStore(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
 
-
-class PushRuleStore(PushRulesWorkerStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see PushRulesWorkerStore.__init__)
-    _push_rules_stream_id_gen: AbstractStreamIdGenerator
-
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-
     async def add_push_rule(
         self,
         user_id: str,
@@ -410,6 +396,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +444,9 @@ class PushRuleStore(PushRulesWorkerStore):
         before: str,
         after: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         relative_to_rule = before or after
 
         sql = """
@@ -524,6 +516,9 @@ class PushRuleStore(PushRulesWorkerStore):
         conditions_json: str,
         actions_json: str,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         if isinstance(self.database_engine, PostgresEngine):
             # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
             # then re-select the count/max below.
@@ -575,6 +570,9 @@ class PushRuleStore(PushRulesWorkerStore):
         actions_json: str,
         update_stream: bool = True,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -653,6 +651,8 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id: The matrix ID of the push rule owner
             rule_id: The rule_id of the rule to be deleted
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
 
         def delete_push_rule_txn(
             txn: LoggingTransaction,
@@ -704,6 +704,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
@@ -727,6 +730,9 @@ class PushRuleStore(PushRulesWorkerStore):
         enabled: bool,
         is_default_rule: bool,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         new_id = self._push_rules_enable_id_gen.get_next()
 
         if not is_default_rule:
@@ -796,6 +802,9 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             RuleNotFoundException if the rule does not exist.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         actions_json = json_encoder.encode(actions)
 
         def set_push_rule_actions_txn(
@@ -865,6 +874,9 @@ class PushRuleStore(PushRulesWorkerStore):
         op: str,
         data: Optional[JsonDict] = None,
     ) -> None:
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -882,9 +894,6 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_max_push_rules_stream_id(self) -> int:
-        return self._push_rules_stream_id_gen.get_current_token()
-
     async def copy_push_rule_from_room_to_room(
         self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
@@ -895,6 +904,9 @@ class PushRuleStore(PushRulesWorkerStore):
             user_id : ID of user the push rule belongs to.
             rule: A push rule.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Create new rule id
         rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +942,9 @@ class PushRuleStore(PushRulesWorkerStore):
             new_room_id: ID of the new room.
             user_id: ID of user to copy push rules for.
         """
+        if not self._is_push_writer:
+            raise Exception("Not a push writer")
+
         # Retrieve push rules for this user
         user_push_rules = await self.get_push_rules_for_user(user_id)