summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2022-09-21 15:39:01 +0100
committerGitHub <noreply@github.com>2022-09-21 14:39:01 +0000
commit8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb (patch)
treedeb3a81241a31a811c26e5bfd64ac56cfc1e83aa /synapse
parentAdd cache invalidation across workers to module API (#13667) (diff)
downloadsynapse-8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb.tar.xz
Support enabling/disabling pushers (from MSC3881) (#13799)
Partial implementation of MSC3881
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/pusherpool.py81
-rw-r--r--synapse/replication/tcp/client.py10
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/pusher.py18
-rw-r--r--synapse/storage/databases/main/pusher.py69
-rw-r--r--synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql16
10 files changed, 154 insertions, 54 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 30983c47fb..450ba462ba 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
     "device_lists_changes_in_room": ["converted_to_destinations"],
+    "pushers": ["enabled"],
 }
 
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 702b81e636..f4541a8db0 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
 
         # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
         self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
+
+        # MSC3881: Remotely toggle push notifications for another client
+        self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 20ec22105a..cfcadb34db 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -997,7 +997,7 @@ class RegistrationHandler:
             assert user_tuple
             token_id = user_tuple.token_id
 
-            await self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="email",
@@ -1005,7 +1005,7 @@ class RegistrationHandler:
                 app_display_name="Email Notifications",
                 device_display_name=threepid["address"],
                 pushkey=threepid["address"],
-                lang=None,  # We don't know a user's language here
+                lang=None,
                 data={},
             )
 
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 57c4d70466..ac99d35a7e 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -116,6 +116,7 @@ class PusherConfig:
     last_stream_ordering: int
     last_success: Optional[int]
     failing_since: Optional[int]
+    enabled: bool
 
     def as_dict(self) -> Dict[str, Any]:
         """Information that can be retrieved about a pusher after creation."""
@@ -128,6 +129,7 @@ class PusherConfig:
             "lang": self.lang,
             "profile_tag": self.profile_tag,
             "pushkey": self.pushkey,
+            "enabled": self.enabled,
         }
 
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 1e0ef44fc7..2597898cf4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -94,7 +94,7 @@ class PusherPool:
             return
         run_as_background_process("start_pushers", self._start_pushers)
 
-    async def add_pusher(
+    async def add_or_update_pusher(
         self,
         user_id: str,
         access_token: Optional[int],
@@ -106,6 +106,7 @@ class PusherPool:
         lang: Optional[str],
         data: JsonDict,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
@@ -147,9 +148,20 @@ class PusherPool:
                 last_stream_ordering=last_stream_ordering,
                 last_success=None,
                 failing_since=None,
+                enabled=enabled,
             )
         )
 
+        # Before we actually persist the pusher, we check if the user already has one
+        # for this app ID and pushkey. If so, we want to keep the access token in place,
+        # since this could be one device modifying (e.g. enabling/disabling) another
+        # device's pusher.
+        existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
+        if existing_config:
+            access_token = existing_config.access_token
+
         await self.store.add_pusher(
             user_id=user_id,
             access_token=access_token,
@@ -163,8 +175,9 @@ class PusherPool:
             data=data,
             last_stream_ordering=last_stream_ordering,
             profile_tag=profile_tag,
+            enabled=enabled,
         )
-        pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
+        pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
 
         return pusher
 
@@ -276,10 +289,25 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(
+    async def _get_pusher_config_for_user_by_app_id_and_pushkey(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> Optional[PusherConfig]:
+        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+
+        pusher_config = None
+        for r in resultlist:
+            if r.user_name == user_id:
+                pusher_config = r
+
+        return pusher_config
+
+    async def process_pusher_change_by_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> Optional[Pusher]:
-        """Look up the details for the given pusher, and start it
+        """Look up the details for the given pusher, and either start it if its
+        "enabled" flag is True, or try to stop it otherwise.
+
+        If the pusher is new and its "enabled" flag is False, the stop is a noop.
 
         Returns:
             The pusher started, if any
@@ -290,12 +318,13 @@ class PusherPool:
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return None
 
-        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+        pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
 
-        pusher_config = None
-        for r in resultlist:
-            if r.user_name == user_id:
-                pusher_config = r
+        if pusher_config and not pusher_config.enabled:
+            self.maybe_stop_pusher(app_id, pushkey, user_id)
+            return None
 
         pusher = None
         if pusher_config:
@@ -305,7 +334,7 @@ class PusherPool:
 
     async def _start_pushers(self) -> None:
         """Start all the pushers"""
-        pushers = await self.store.get_all_pushers()
+        pushers = await self.store.get_enabled_pushers()
 
         # Stagger starting up the pushers so we don't completely drown the
         # process on start up.
@@ -363,6 +392,8 @@ class PusherPool:
 
         synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
+        logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
+
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
@@ -382,16 +413,7 @@ class PusherPool:
         return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
-        appid_pushkey = "%s:%s" % (app_id, pushkey)
-
-        byuser = self.pushers.get(user_id, {})
-
-        if appid_pushkey in byuser:
-            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            pusher = byuser.pop(appid_pushkey)
-            pusher.on_stop()
-
-            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
+        self.maybe_stop_pusher(app_id, pushkey, user_id)
 
         # We can only delete pushers on master.
         if self._remove_pusher_client:
@@ -402,3 +424,22 @@ class PusherPool:
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id, pushkey, user_id
             )
+
+    def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
+        """Stops a pusher with the given app ID and push key if one is running.
+
+        Args:
+            app_id: the pusher's app ID.
+            pushkey: the pusher's push key.
+            user_id: the user the pusher belongs to. Only used for logging.
+        """
+        appid_pushkey = "%s:%s" % (app_id, pushkey)
+
+        byuser = self.pushers.get(user_id, {})
+
+        if appid_pushkey in byuser:
+            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e4f2201c92..cf9cd6833b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,7 +189,9 @@ class ReplicationDataHandler:
                 if row.deleted:
                     self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                 else:
-                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+                    await self.process_pusher_change(
+                        row.user_id, row.app_id, row.pushkey
+                    )
         elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
@@ -334,13 +336,15 @@ class ReplicationDataHandler:
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()
 
-    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
+    async def process_pusher_change(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> None:
         if not self._notify_pushers:
             return
 
         key = "%s:%s" % (app_id, pushkey)
         logger.info("Starting pusher %r / %r", user_id, key)
-        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
 
 
 class FederationSenderHandler:
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2ca6b2d08a..1274773d7e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
                         and self.hs.config.email.email_notif_for_new_users
                         and medium == "email"
                     ):
-                        await self.pusher_pool.add_pusher(
+                        await self.pusher_pool.add_or_update_pusher(
                             user_id=user_id,
                             access_token=None,
                             kind="email",
@@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
                             app_display_name="Email Notifications",
                             device_display_name=address,
                             pushkey=address,
-                            lang=None,  # We don't know a user's language here
+                            lang=None,
                             data={},
                         )
 
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 9a1f10f4be..c9f76125dc 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
+        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet):
             user.to_string()
         )
 
-        filtered_pushers = [p.as_dict() for p in pushers]
+        pusher_dicts = [p.as_dict() for p in pushers]
 
-        return 200, {"pushers": filtered_pushers}
+        for pusher in pusher_dicts:
+            if self._msc3881_enabled:
+                pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
+            del pusher["enabled"]
+
+        return 200, {"pushers": pusher_dicts}
 
 
 class PushersSetRestServlet(RestServlet):
@@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.pusher_pool = self.hs.get_pusherpool()
+        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet):
         if "append" in content:
             append = content["append"]
 
+        enabled = True
+        if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
+            enabled = content["org.matrix.msc3881.enabled"]
+
         if not append:
             await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
                 app_id=content["app_id"],
@@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet):
             )
 
         try:
-            await self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_or_update_pusher(
                 user_id=user.to_string(),
                 access_token=requester.access_token_id,
                 kind=content["kind"],
@@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet):
                 lang=content["lang"],
                 data=content["data"],
                 profile_tag=content.get("profile_tag", ""),
+                enabled=enabled,
             )
         except PusherConfigException as pce:
             raise SynapseError(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bd0cfa7f32..ee55b8c4a9 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
+            # If we're using SQLite, then boolean values are integers. This is
+            # troublesome since some code using the return value of this method might
+            # expect it to be a boolean, or will expose it to clients (in responses).
+            r["enabled"] = bool(r["enabled"])
+
             yield PusherConfig(**r)
 
     async def get_pushers_by_app_id_and_pushkey(
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
         return await self.get_pushers_by({"user_name": user_id})
 
     async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
-        ret = await self.db_pool.simple_select_list(
-            "pushers",
-            keyvalues,
-            [
-                "id",
-                "user_name",
-                "access_token",
-                "profile_tag",
-                "kind",
-                "app_id",
-                "app_display_name",
-                "device_display_name",
-                "pushkey",
-                "ts",
-                "lang",
-                "data",
-                "last_stream_ordering",
-                "last_success",
-                "failing_since",
-            ],
+        """Retrieve pushers that match the given criteria.
+
+        Args:
+            keyvalues: A {column: value} dictionary.
+
+        Returns:
+            The pushers for which the given columns have the given values.
+        """
+
+        def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+            # We could technically use simple_select_list here, but we need to call
+            # COALESCE on the 'enabled' column. While it is technically possible to give
+            # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
+            # feels a bit hacky, so it's probably better to just inline the query.
+            sql = """
+            SELECT
+                id, user_name, access_token, profile_tag, kind, app_id,
+                app_display_name, device_display_name, pushkey, ts, lang, data,
+                last_stream_ordering, last_success, failing_since,
+                COALESCE(enabled, TRUE) AS enabled
+            FROM pushers
+            """
+
+            sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
+
+            txn.execute(sql, list(keyvalues.values()))
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        ret = await self.db_pool.runInteraction(
             desc="get_pushers_by",
+            func=get_pushers_by_txn,
         )
+
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
-            txn.execute("SELECT * FROM pushers")
+    async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
+        def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
+            txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
             rows = self.db_pool.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
+        return await self.db_pool.runInteraction(
+            "get_enabled_pushers", get_enabled_pushers_txn
+        )
 
     async def get_all_updated_pushers_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
         data: Optional[JsonDict],
         last_stream_ordering: int,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "enabled": enabled,
                 },
                 desc="add_pusher",
                 lock=False,
diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql
new file mode 100644
index 0000000000..dba3b4900b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql
@@ -0,0 +1,16 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;
\ No newline at end of file