summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13799.feature1
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/pusherpool.py81
-rw-r--r--synapse/replication/tcp/client.py10
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/pusher.py18
-rw-r--r--synapse/storage/databases/main/pusher.py69
-rw-r--r--synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql16
-rw-r--r--tests/push/test_email.py4
-rw-r--r--tests/push/test_http.py148
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_user.py2
15 files changed, 294 insertions, 71 deletions
diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature
new file mode 100644
index 0000000000..6c8e5cffe2
--- /dev/null
+++ b/changelog.d/13799.feature
@@ -0,0 +1 @@
+Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 30983c47fb..450ba462ba 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
     "e2e_fallback_keys_json": ["used"],
     "access_tokens": ["used"],
     "device_lists_changes_in_room": ["converted_to_destinations"],
+    "pushers": ["enabled"],
 }
 
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 702b81e636..f4541a8db0 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
 
         # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
         self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
+
+        # MSC3881: Remotely toggle push notifications for another client
+        self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 20ec22105a..cfcadb34db 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -997,7 +997,7 @@ class RegistrationHandler:
             assert user_tuple
             token_id = user_tuple.token_id
 
-            await self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="email",
@@ -1005,7 +1005,7 @@ class RegistrationHandler:
                 app_display_name="Email Notifications",
                 device_display_name=threepid["address"],
                 pushkey=threepid["address"],
-                lang=None,  # We don't know a user's language here
+                lang=None,
                 data={},
             )
 
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 57c4d70466..ac99d35a7e 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -116,6 +116,7 @@ class PusherConfig:
     last_stream_ordering: int
     last_success: Optional[int]
     failing_since: Optional[int]
+    enabled: bool
 
     def as_dict(self) -> Dict[str, Any]:
         """Information that can be retrieved about a pusher after creation."""
@@ -128,6 +129,7 @@ class PusherConfig:
             "lang": self.lang,
             "profile_tag": self.profile_tag,
             "pushkey": self.pushkey,
+            "enabled": self.enabled,
         }
 
 
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 1e0ef44fc7..2597898cf4 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -94,7 +94,7 @@ class PusherPool:
             return
         run_as_background_process("start_pushers", self._start_pushers)
 
-    async def add_pusher(
+    async def add_or_update_pusher(
         self,
         user_id: str,
         access_token: Optional[int],
@@ -106,6 +106,7 @@ class PusherPool:
         lang: Optional[str],
         data: JsonDict,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
@@ -147,9 +148,20 @@ class PusherPool:
                 last_stream_ordering=last_stream_ordering,
                 last_success=None,
                 failing_since=None,
+                enabled=enabled,
             )
         )
 
+        # Before we actually persist the pusher, we check if the user already has one
+        # for this app ID and pushkey. If so, we want to keep the access token in place,
+        # since this could be one device modifying (e.g. enabling/disabling) another
+        # device's pusher.
+        existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
+        if existing_config:
+            access_token = existing_config.access_token
+
         await self.store.add_pusher(
             user_id=user_id,
             access_token=access_token,
@@ -163,8 +175,9 @@ class PusherPool:
             data=data,
             last_stream_ordering=last_stream_ordering,
             profile_tag=profile_tag,
+            enabled=enabled,
         )
-        pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
+        pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
 
         return pusher
 
@@ -276,10 +289,25 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
 
-    async def start_pusher_by_id(
+    async def _get_pusher_config_for_user_by_app_id_and_pushkey(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> Optional[PusherConfig]:
+        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+
+        pusher_config = None
+        for r in resultlist:
+            if r.user_name == user_id:
+                pusher_config = r
+
+        return pusher_config
+
+    async def process_pusher_change_by_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> Optional[Pusher]:
-        """Look up the details for the given pusher, and start it
+        """Look up the details for the given pusher, and either start it if its
+        "enabled" flag is True, or try to stop it otherwise.
+
+        If the pusher is new and its "enabled" flag is False, the stop is a noop.
 
         Returns:
             The pusher started, if any
@@ -290,12 +318,13 @@ class PusherPool:
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return None
 
-        resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+        pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+            user_id, app_id, pushkey
+        )
 
-        pusher_config = None
-        for r in resultlist:
-            if r.user_name == user_id:
-                pusher_config = r
+        if pusher_config and not pusher_config.enabled:
+            self.maybe_stop_pusher(app_id, pushkey, user_id)
+            return None
 
         pusher = None
         if pusher_config:
@@ -305,7 +334,7 @@ class PusherPool:
 
     async def _start_pushers(self) -> None:
         """Start all the pushers"""
-        pushers = await self.store.get_all_pushers()
+        pushers = await self.store.get_enabled_pushers()
 
         # Stagger starting up the pushers so we don't completely drown the
         # process on start up.
@@ -363,6 +392,8 @@ class PusherPool:
 
         synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
+        logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
+
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
@@ -382,16 +413,7 @@ class PusherPool:
         return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
-        appid_pushkey = "%s:%s" % (app_id, pushkey)
-
-        byuser = self.pushers.get(user_id, {})
-
-        if appid_pushkey in byuser:
-            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
-            pusher = byuser.pop(appid_pushkey)
-            pusher.on_stop()
-
-            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
+        self.maybe_stop_pusher(app_id, pushkey, user_id)
 
         # We can only delete pushers on master.
         if self._remove_pusher_client:
@@ -402,3 +424,22 @@ class PusherPool:
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id, pushkey, user_id
             )
+
+    def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
+        """Stops a pusher with the given app ID and push key if one is running.
+
+        Args:
+            app_id: the pusher's app ID.
+            pushkey: the pusher's push key.
+            user_id: the user the pusher belongs to. Only used for logging.
+        """
+        appid_pushkey = "%s:%s" % (app_id, pushkey)
+
+        byuser = self.pushers.get(user_id, {})
+
+        if appid_pushkey in byuser:
+            logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
+            pusher = byuser.pop(appid_pushkey)
+            pusher.on_stop()
+
+            synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e4f2201c92..cf9cd6833b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -189,7 +189,9 @@ class ReplicationDataHandler:
                 if row.deleted:
                     self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                 else:
-                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+                    await self.process_pusher_change(
+                        row.user_id, row.app_id, row.pushkey
+                    )
         elif stream_name == EventsStream.NAME:
             # We shouldn't get multiple rows per token for events stream, so
             # we don't need to optimise this for multiple rows.
@@ -334,13 +336,15 @@ class ReplicationDataHandler:
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()
 
-    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
+    async def process_pusher_change(
+        self, user_id: str, app_id: str, pushkey: str
+    ) -> None:
         if not self._notify_pushers:
             return
 
         key = "%s:%s" % (app_id, pushkey)
         logger.info("Starting pusher %r / %r", user_id, key)
-        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
 
 
 class FederationSenderHandler:
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2ca6b2d08a..1274773d7e 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
                         and self.hs.config.email.email_notif_for_new_users
                         and medium == "email"
                     ):
-                        await self.pusher_pool.add_pusher(
+                        await self.pusher_pool.add_or_update_pusher(
                             user_id=user_id,
                             access_token=None,
                             kind="email",
@@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
                             app_display_name="Email Notifications",
                             device_display_name=address,
                             pushkey=address,
-                            lang=None,  # We don't know a user's language here
+                            lang=None,
                             data={},
                         )
 
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 9a1f10f4be..c9f76125dc 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
+        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet):
             user.to_string()
         )
 
-        filtered_pushers = [p.as_dict() for p in pushers]
+        pusher_dicts = [p.as_dict() for p in pushers]
 
-        return 200, {"pushers": filtered_pushers}
+        for pusher in pusher_dicts:
+            if self._msc3881_enabled:
+                pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
+            del pusher["enabled"]
+
+        return 200, {"pushers": pusher_dicts}
 
 
 class PushersSetRestServlet(RestServlet):
@@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
         self.pusher_pool = self.hs.get_pusherpool()
+        self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
@@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet):
         if "append" in content:
             append = content["append"]
 
+        enabled = True
+        if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
+            enabled = content["org.matrix.msc3881.enabled"]
+
         if not append:
             await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
                 app_id=content["app_id"],
@@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet):
             )
 
         try:
-            await self.pusher_pool.add_pusher(
+            await self.pusher_pool.add_or_update_pusher(
                 user_id=user.to_string(),
                 access_token=requester.access_token_id,
                 kind=content["kind"],
@@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet):
                 lang=content["lang"],
                 data=content["data"],
                 profile_tag=content.get("profile_tag", ""),
+                enabled=enabled,
             )
         except PusherConfigException as pce:
             raise SynapseError(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bd0cfa7f32..ee55b8c4a9 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
+            # If we're using SQLite, then boolean values are integers. This is
+            # troublesome since some code using the return value of this method might
+            # expect it to be a boolean, or will expose it to clients (in responses).
+            r["enabled"] = bool(r["enabled"])
+
             yield PusherConfig(**r)
 
     async def get_pushers_by_app_id_and_pushkey(
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
         return await self.get_pushers_by({"user_name": user_id})
 
     async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
-        ret = await self.db_pool.simple_select_list(
-            "pushers",
-            keyvalues,
-            [
-                "id",
-                "user_name",
-                "access_token",
-                "profile_tag",
-                "kind",
-                "app_id",
-                "app_display_name",
-                "device_display_name",
-                "pushkey",
-                "ts",
-                "lang",
-                "data",
-                "last_stream_ordering",
-                "last_success",
-                "failing_since",
-            ],
+        """Retrieve pushers that match the given criteria.
+
+        Args:
+            keyvalues: A {column: value} dictionary.
+
+        Returns:
+            The pushers for which the given columns have the given values.
+        """
+
+        def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+            # We could technically use simple_select_list here, but we need to call
+            # COALESCE on the 'enabled' column. While it is technically possible to give
+            # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
+            # feels a bit hacky, so it's probably better to just inline the query.
+            sql = """
+            SELECT
+                id, user_name, access_token, profile_tag, kind, app_id,
+                app_display_name, device_display_name, pushkey, ts, lang, data,
+                last_stream_ordering, last_success, failing_since,
+                COALESCE(enabled, TRUE) AS enabled
+            FROM pushers
+            """
+
+            sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
+
+            txn.execute(sql, list(keyvalues.values()))
+
+            return self.db_pool.cursor_to_dict(txn)
+
+        ret = await self.db_pool.runInteraction(
             desc="get_pushers_by",
+            func=get_pushers_by_txn,
         )
+
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
-            txn.execute("SELECT * FROM pushers")
+    async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
+        def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
+            txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
             rows = self.db_pool.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
+        return await self.db_pool.runInteraction(
+            "get_enabled_pushers", get_enabled_pushers_txn
+        )
 
     async def get_all_updated_pushers_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
         data: Optional[JsonDict],
         last_stream_ordering: int,
         profile_tag: str = "",
+        enabled: bool = True,
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "enabled": enabled,
                 },
                 desc="add_pusher",
                 lock=False,
diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql
new file mode 100644
index 0000000000..dba3b4900b
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql
@@ -0,0 +1,16 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;
\ No newline at end of file
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7a3b0d6755..fd14568f55 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         self.pusher = self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.user_id,
                 access_token=self.token_id,
                 kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
         """
         with self.assertRaises(SynapseError) as cm:
             self.get_success_or_raise(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=self.user_id,
                     access_token=self.token_id,
                     kind="email",
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index d9c68cdd2d..af67d84463 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
-from synapse.push import PusherConfigException
-from synapse.rest.client import login, push_rule, receipts, room
+from synapse.push import PusherConfig, PusherConfigException
+from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase):
         login.register_servlets,
         receipts.register_servlets,
         push_rule.register_servlets,
+        pusher.register_servlets,
     ]
     user_id = True
     hijack_auth = False
@@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         def test_data(data: Optional[JsonDict]) -> None:
             self.get_failure(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=user_id,
                     access_token=token_id,
                     kind="http",
@@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
-    def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+    def _make_user_with_pusher(
+        self, username: str, enabled: bool = True
+    ) -> Tuple[str, str]:
+        """Registers a user and creates a pusher for them.
+
+        Args:
+            username: the localpart of the new user's Matrix ID.
+            enabled: whether to create the pusher in an enabled or disabled state.
+        """
         user_id = self.register_user(username, "pass")
         access_token = self.login(username, "pass")
 
         # Register the pusher
+        self._set_pusher(user_id, access_token, enabled)
+
+        return user_id, access_token
+
+    def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
+        """Creates or updates the pusher for the given user.
+
+        Args:
+            user_id: the user's Matrix ID.
+            access_token: the access token associated with the pusher.
+            enabled: whether to enable or disable the pusher.
+        """
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase):
                 pushkey="a@example.com",
                 lang=None,
                 data={"url": "http://example.com/_matrix/push/v1/notify"},
+                enabled=enabled,
             )
         )
 
-        return user_id, access_token
-
     def test_dont_notify_rule_overrides_message(self) -> None:
         """
         The override push rule will suppress notification
@@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase):
         # The user sends a message back (sends a notification)
         self.helper.send(room, body="Hello", tok=access_token)
         self.assertEqual(len(self.push_attempts), 1)
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_disable(self) -> None:
+        """Tests that disabling a pusher means it's not pushed to anymore."""
+        user_id, access_token = self._make_user_with_pusher("user")
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it generated a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Disable the pusher.
+        self._set_pusher(user_id, access_token, enabled=False)
+
+        # Send another message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as disabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertFalse(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_enable(self) -> None:
+        """Tests that enabling a disabled pusher means it gets pushed to."""
+        # Create the user with the pusher already disabled.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=False)
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 0)
+
+        # Enable the pusher.
+        self._set_pusher(user_id, access_token, enabled=True)
+
+        # Send another message and check that it did generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as enabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertTrue(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_null_enabled(self) -> None:
+        """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
+        created before the column was introduced) is considered enabled.
+        """
+        # We intentionally set 'enabled' to None so that it's stored as NULL in the
+        # database.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=None)  # type: ignore[arg-type]
+
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
+
+    def test_update_different_device_access_token(self) -> None:
+        """Tests that if we create a pusher from one device, the update it from another
+        device, the access token associated with the pusher stays the same.
+        """
+        # Create a user with a pusher.
+        user_id, access_token = self._make_user_with_pusher("user")
+
+        # Get the token ID for the current access token, since that's what we store in
+        # the pushers table.
+        user_tuple = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        # Generate a new access token, and update the pusher with it.
+        new_token = self.login("user", "pass")
+        self._set_pusher(user_id, new_token, enabled=False)
+
+        # Get the current list of pushers for the user.
+        ret = self.get_success(
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        )
+        pushers: List[PusherConfig] = list(ret)
+
+        # Check that we still have one pusher, and that the access token associated with
+        # it didn't change.
+        self.assertEqual(len(pushers), 1)
+        self.assertEqual(pushers[0].access_token, token_id)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         token_id = user_dict.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9f536ceeb3..1847e6ad6b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.other_user,
                 access_token=token_id,
                 kind="http",