summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13831.feature1
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/pusherpool.py10
-rw-r--r--synapse/rest/client/pusher.py3
-rw-r--r--synapse/storage/databases/main/pusher.py73
-rw-r--r--synapse/storage/schema/main/delta/73/03pusher_device_id.sql20
-rw-r--r--tests/push/test_http.py55
7 files changed, 154 insertions, 10 deletions
diff --git a/changelog.d/13831.feature b/changelog.d/13831.feature
new file mode 100644
index 0000000000..6c8e5cffe2
--- /dev/null
+++ b/changelog.d/13831.feature
@@ -0,0 +1 @@
+Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index ac99d35a7e..a0c760239d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -117,6 +117,7 @@ class PusherConfig:
     last_success: Optional[int]
     failing_since: Optional[int]
     enabled: bool
+    device_id: Optional[str]
 
     def as_dict(self) -> Dict[str, Any]:
         """Information that can be retrieved about a pusher after creation."""
@@ -130,6 +131,7 @@ class PusherConfig:
             "profile_tag": self.profile_tag,
             "pushkey": self.pushkey,
             "enabled": self.enabled,
+            "device_id": self.device_id,
         }
 
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2597898cf4..e2648cbc93 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -107,6 +107,7 @@ class PusherPool:
         data: JsonDict,
         profile_tag: str = "",
         enabled: bool = True,
+        device_id: Optional[str] = None,
     ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
@@ -149,18 +150,20 @@ class PusherPool:
                 last_success=None,
                 failing_since=None,
                 enabled=enabled,
+                device_id=device_id,
             )
         )
 
         # Before we actually persist the pusher, we check if the user already has one
-        # for this app ID and pushkey. If so, we want to keep the access token in place,
-        # since this could be one device modifying (e.g. enabling/disabling) another
-        # device's pusher.
+        # this app ID and pushkey. If so, we want to keep the access token and device ID
+        # in place, since this could be one device modifying (e.g. enabling/disabling)
+        # another device's pusher.
         existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
             user_id, app_id, pushkey
         )
         if existing_config:
             access_token = existing_config.access_token
+            device_id = existing_config.device_id
 
         await self.store.add_pusher(
             user_id=user_id,
@@ -176,6 +179,7 @@ class PusherPool:
             last_stream_ordering=last_stream_ordering,
             profile_tag=profile_tag,
             enabled=enabled,
+            device_id=device_id,
         )
         pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
 
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index c9f76125dc..975eef2144 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet):
         for pusher in pusher_dicts:
             if self._msc3881_enabled:
                 pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
+                pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
             del pusher["enabled"]
+            del pusher["device_id"]
 
         return 200, {"pushers": pusher_dicts}
 
@@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet):
                 data=content["data"],
                 profile_tag=content.get("profile_tag", ""),
                 enabled=enabled,
+                device_id=requester.device_id,
             )
         except PusherConfigException as pce:
             raise SynapseError(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index ee55b8c4a9..01206950a9 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore):
                 id, user_name, access_token, profile_tag, kind, app_id,
                 app_display_name, device_display_name, pushkey, ts, lang, data,
                 last_stream_ordering, last_success, failing_since,
-                COALESCE(enabled, TRUE) AS enabled
+                COALESCE(enabled, TRUE) AS enabled, device_id
             FROM pushers
             """
 
@@ -477,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore):
         return number_deleted
 
 
-class PusherStore(PusherWorkerStore):
+class PusherBackgroundUpdatesStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_update_handler(
+            "set_device_id_for_pushers", self._set_device_id_for_pushers
+        )
+
+    async def _set_device_id_for_pushers(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update to populate the device_id column of the pushers table."""
+        last_pusher_id = progress.get("pusher_id", 0)
+
+        def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
+            txn.execute(
+                """
+                    SELECT p.id, at.device_id
+                    FROM pushers AS p
+                    INNER JOIN access_tokens AS at
+                        ON p.access_token = at.id
+                    WHERE
+                        p.access_token IS NOT NULL
+                        AND at.device_id IS NOT NULL
+                        AND p.id > ?
+                    ORDER BY p.id
+                    LIMIT ?
+                """,
+                (last_pusher_id, batch_size),
+            )
+
+            rows = self.db_pool.cursor_to_dict(txn)
+            if len(rows) == 0:
+                return 0
+
+            self.db_pool.simple_update_many_txn(
+                txn=txn,
+                table="pushers",
+                key_names=("id",),
+                key_values=[(row["id"],) for row in rows],
+                value_names=("device_id",),
+                value_values=[(row["device_id"],) for row in rows],
+            )
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
+            )
+
+            return len(rows)
+
+        nb_processed = await self.db_pool.runInteraction(
+            "set_device_id_for_pushers", set_device_id_for_pushers_txn
+        )
+
+        if nb_processed < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "set_device_id_for_pushers"
+            )
+
+        return nb_processed
+
+
+class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
@@ -496,6 +563,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering: int,
         profile_tag: str = "",
         enabled: bool = True,
+        device_id: Optional[str] = None,
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -515,6 +583,7 @@ class PusherStore(PusherWorkerStore):
                     "profile_tag": profile_tag,
                     "id": stream_id,
                     "enabled": enabled,
+                    "device_id": device_id,
                 },
                 desc="add_pusher",
                 lock=False,
diff --git a/synapse/storage/schema/main/delta/73/03pusher_device_id.sql b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql
new file mode 100644
index 0000000000..1b4ffbeebe
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql
@@ -0,0 +1,20 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a device_id column to track the device ID that created the pusher. It's NULLable
+-- on purpose, because a) it might not be possible to track down the device that created
+-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and
+-- b) access tokens retrieved via the admin API don't have a device associated to them.
+ALTER TABLE pushers ADD COLUMN device_id TEXT;
\ No newline at end of file
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index af67d84463..b383b8401f 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable
 from synapse.push import PusherConfig, PusherConfigException
 from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import JsonDict
 from synapse.util import Clock
 
@@ -771,6 +772,7 @@ class HTTPPusherTests(HomeserverTestCase):
                 lang=None,
                 data={"url": "http://example.com/_matrix/push/v1/notify"},
                 enabled=enabled,
+                device_id=user_tuple.device_id,
             )
         )
 
@@ -885,19 +887,21 @@ class HTTPPusherTests(HomeserverTestCase):
         self.assertEqual(len(channel.json_body["pushers"]), 1)
         self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
 
-    def test_update_different_device_access_token(self) -> None:
+    def test_update_different_device_access_token_device_id(self) -> None:
         """Tests that if we create a pusher from one device, the update it from another
-        device, the access token associated with the pusher stays the same.
+        device, the access token and device ID associated with the pusher stays the
+        same.
         """
         # Create a user with a pusher.
         user_id, access_token = self._make_user_with_pusher("user")
 
         # Get the token ID for the current access token, since that's what we store in
-        # the pushers table.
+        # the pushers table. Also get the device ID from it.
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
+        device_id = user_tuple.device_id
 
         # Generate a new access token, and update the pusher with it.
         new_token = self.login("user", "pass")
@@ -909,7 +913,48 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers: List[PusherConfig] = list(ret)
 
-        # Check that we still have one pusher, and that the access token associated with
-        # it didn't change.
+        # Check that we still have one pusher, and that the access token and device ID
+        # associated with it didn't change.
         self.assertEqual(len(pushers), 1)
         self.assertEqual(pushers[0].access_token, token_id)
+        self.assertEqual(pushers[0].device_id, device_id)
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_device_id(self) -> None:
+        """Tests that a pusher created with a given device ID shows that device ID in
+        GET /pushers requests.
+        """
+        self.register_user("user", "pass")
+        access_token = self.login("user", "pass")
+
+        # We create the pusher with an HTTP request rather than with
+        # _make_user_with_pusher so that we can test the device ID is correctly set when
+        # creating a pusher via an API call.
+        self.make_request(
+            method="POST",
+            path="/pushers/set",
+            content={
+                "kind": "http",
+                "app_id": "m.http",
+                "app_display_name": "HTTP Push Notifications",
+                "device_display_name": "pushy push",
+                "pushkey": "a@example.com",
+                "lang": "en",
+                "data": {"url": "http://example.com/_matrix/push/v1/notify"},
+            },
+            access_token=access_token,
+        )
+
+        # Look up the user info for the access token so we can compare the device ID.
+        lookup_result: TokenLookupResult = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+
+        # Get the user's devices and check it has the correct device ID.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertEqual(
+            channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
+            lookup_result.device_id,
+        )