summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/push/test_email.py4
-rw-r--r--tests/push/test_http.py148
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/rest/admin/test_user.py2
4 files changed, 139 insertions, 17 deletions
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 7a3b0d6755..fd14568f55 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
 
         self.pusher = self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.user_id,
                 access_token=self.token_id,
                 kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
         """
         with self.assertRaises(SynapseError) as cm:
             self.get_success_or_raise(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=self.user_id,
                     access_token=self.token_id,
                     kind="email",
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index d9c68cdd2d..af67d84463 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.logging.context import make_deferred_yieldable
-from synapse.push import PusherConfigException
-from synapse.rest.client import login, push_rule, receipts, room
+from synapse.push import PusherConfig, PusherConfigException
+from synapse.rest.client import login, push_rule, pusher, receipts, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase):
         login.register_servlets,
         receipts.register_servlets,
         push_rule.register_servlets,
+        pusher.register_servlets,
     ]
     user_id = True
     hijack_auth = False
@@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         def test_data(data: Optional[JsonDict]) -> None:
             self.get_failure(
-                self.hs.get_pusherpool().add_pusher(
+                self.hs.get_pusherpool().add_or_update_pusher(
                     user_id=user_id,
                     access_token=token_id,
                     kind="http",
@@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         self.assertEqual(channel.code, 200, channel.json_body)
 
-    def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+    def _make_user_with_pusher(
+        self, username: str, enabled: bool = True
+    ) -> Tuple[str, str]:
+        """Registers a user and creates a pusher for them.
+
+        Args:
+            username: the localpart of the new user's Matrix ID.
+            enabled: whether to create the pusher in an enabled or disabled state.
+        """
         user_id = self.register_user(username, "pass")
         access_token = self.login(username, "pass")
 
         # Register the pusher
+        self._set_pusher(user_id, access_token, enabled)
+
+        return user_id, access_token
+
+    def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
+        """Creates or updates the pusher for the given user.
+
+        Args:
+            user_id: the user's Matrix ID.
+            access_token: the access token associated with the pusher.
+            enabled: whether to enable or disable the pusher.
+        """
         user_tuple = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(access_token)
         )
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
@@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase):
                 pushkey="a@example.com",
                 lang=None,
                 data={"url": "http://example.com/_matrix/push/v1/notify"},
+                enabled=enabled,
             )
         )
 
-        return user_id, access_token
-
     def test_dont_notify_rule_overrides_message(self) -> None:
         """
         The override push rule will suppress notification
@@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase):
         # The user sends a message back (sends a notification)
         self.helper.send(room, body="Hello", tok=access_token)
         self.assertEqual(len(self.push_attempts), 1)
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_disable(self) -> None:
+        """Tests that disabling a pusher means it's not pushed to anymore."""
+        user_id, access_token = self._make_user_with_pusher("user")
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it generated a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Disable the pusher.
+        self._set_pusher(user_id, access_token, enabled=False)
+
+        # Send another message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as disabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertFalse(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_enable(self) -> None:
+        """Tests that enabling a disabled pusher means it gets pushed to."""
+        # Create the user with the pusher already disabled.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=False)
+        other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
+
+        room = self.helper.create_room_as(user_id, tok=access_token)
+        self.helper.join(room=room, user=other_user_id, tok=other_access_token)
+
+        # Send a message and check that it did not generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 0)
+
+        # Enable the pusher.
+        self._set_pusher(user_id, access_token, enabled=True)
+
+        # Send another message and check that it did generate a push.
+        self.helper.send(room, body="Hi!", tok=other_access_token)
+        self.assertEqual(len(self.push_attempts), 1)
+
+        # Get the pushers for the user and check that it is marked as enabled.
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+
+        enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
+        self.assertTrue(enabled)
+        self.assertTrue(isinstance(enabled, bool))
+
+    @override_config({"experimental_features": {"msc3881_enabled": True}})
+    def test_null_enabled(self) -> None:
+        """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
+        created before the column was introduced) is considered enabled.
+        """
+        # We intentionally set 'enabled' to None so that it's stored as NULL in the
+        # database.
+        user_id, access_token = self._make_user_with_pusher("user", enabled=None)  # type: ignore[arg-type]
+
+        channel = self.make_request("GET", "/pushers", access_token=access_token)
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(len(channel.json_body["pushers"]), 1)
+        self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
+
+    def test_update_different_device_access_token(self) -> None:
+        """Tests that if we create a pusher from one device, the update it from another
+        device, the access token associated with the pusher stays the same.
+        """
+        # Create a user with a pusher.
+        user_id, access_token = self._make_user_with_pusher("user")
+
+        # Get the token ID for the current access token, since that's what we store in
+        # the pushers table.
+        user_tuple = self.get_success(
+            self.hs.get_datastores().main.get_user_by_access_token(access_token)
+        )
+        token_id = user_tuple.token_id
+
+        # Generate a new access token, and update the pusher with it.
+        new_token = self.login("user", "pass")
+        self._set_pusher(user_id, new_token, enabled=False)
+
+        # Get the current list of pushers for the user.
+        ret = self.get_success(
+            self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
+        )
+        pushers: List[PusherConfig] = list(ret)
+
+        # Check that we still have one pusher, and that the access token associated with
+        # it didn't change.
+        self.assertEqual(len(pushers), 1)
+        self.assertEqual(pushers[0].access_token, token_id)
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         token_id = user_dict.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=user_id,
                 access_token=token_id,
                 kind="http",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 9f536ceeb3..1847e6ad6b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         token_id = user_tuple.token_id
 
         self.get_success(
-            self.hs.get_pusherpool().add_pusher(
+            self.hs.get_pusherpool().add_or_update_pusher(
                 user_id=self.other_user,
                 access_token=token_id,
                 kind="http",