summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_catch_up.py49
-rw-r--r--tests/rest/admin/test_user.py173
2 files changed, 167 insertions, 55 deletions
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
 
 from mock import Mock
 
+from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.federation.sender import PerDestinationQueue, TransactionManager
 from synapse.federation.units import Edu
@@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.assertNotIn("zzzerver", woken)
         # - all destinations are woken exactly once; they appear once in woken.
         self.assertCountEqual(woken, server_names[:-1])
+
+    @override_config({"send_federation": True})
+    def test_not_latest_event(self):
+        """Test that we send the latest event in the room even if its not ours."""
+
+        per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+        # Make a room with a local user, and two servers. One will go offline
+        # and one will send some events.
+        self.register_user("u1", "you the one")
+        u1_token = self.login("u1", "you the one")
+        room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+        )
+        event_1 = self.get_success(
+            event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+        )
+
+        # First we send something from the local server, so that we notice the
+        # remote is down and go into catchup mode.
+        self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+        # Now simulate us receiving an event from the still online remote.
+        event_2 = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                type=EventTypes.Message,
+                sender="@user:host3",
+                room_id=room_1,
+                content={"msgtype": "m.text", "body": "Hello"},
+            )
+        )
+
+        self.get_success(
+            self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+                "host2", event_1.internal_metadata.stream_ordering
+            )
+        )
+
+        self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+        # We expect only the last message from the remote, event_2, to have been
+        # sent, rather than the last *local* event that was sent.
+        self.assertEqual(len(sent_pdus), 1)
+        self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+        self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..cf61f284cb 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.auth_handler = hs.get_auth_handler()
 
+        # create users and get access tokens
+        # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
+        self.admin_user_tok = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.admin_user, device_id=None, valid_until_ms=None
+            )
+        )
 
         self.other_user = self.register_user("user", "pass", displayname="User")
-        self.other_user_token = self.login("user", "pass")
+        self.other_user_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.other_user, device_id=None, valid_until_ms=None
+            )
+        )
         self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
             self.other_user
         )
@@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
@@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
-        self.assertEqual(False, channel.json_body["shadow_banned"])
+        self.assertFalse(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["shadow_banned"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
@@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Admin user is not blocked by mau anymore
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {
@@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertTrue(profile["display_name"] == "User")
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
         # Set new displayname user
-        body = json.dumps({"displayname": "Foobar"})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"displayname": "Foobar"},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
         self.assertEqual("Foobar", channel.json_body["displayname"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
     def test_reactivate_user(self):
         """
@@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Attempt to reactivate the user (without a password).
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNotNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
-        d = self.store.mark_user_erased("@user:test")
-        self.assertIsNone(self.get_success(d))
-        self._is_erased("@user:test", True)
 
-        # Attempt to reactivate the user (without a password).
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_reactivate_user_localdb_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        # Reactivate the user.
+        # Reactivate the user without a password.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
-                encoding="utf_8"
-            ),
+            content={"deactivated": False},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased("@user:test", False)
 
-        # Get user
+    @override_config({"password_config": {"enabled": False}})
+    def test_reactivate_user_password_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
-            "GET",
+            "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
+            content={"deactivated": False, "password": "foo"},
         )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
+        # Reactivate the user without a password.
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
 
     def test_set_user_as_admin(self):
@@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Set a user as an admin
-        body = json.dumps({"admin": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"admin": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
         # Get user
         channel = self.make_request(
@@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
     def test_accidental_deactivation_prevention(self):
         """
@@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v2/users/@bob:test"
 
         # Create user
-        body = json.dumps({"password": "abc123"})
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123"},
         )
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, channel.json_body["deactivated"])
 
         # Change password (and use a str for deactivate instead of a bool)
-        body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123", "deactivated": "false"},
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])
 
-    def _is_erased(self, user_id, expect):
+    def _is_erased(self, user_id: str, expect: bool) -> None:
         """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
         if expect:
@@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         else:
             self.assertFalse(self.get_success(d))
 
+    def _deactivate_user(self, user_id: str) -> None:
+        """Deactivate user and set as erased"""
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+            access_token=self.admin_user_tok,
+            content={"deactivated": True},
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased(user_id, False)
+        d = self.store.mark_user_erased(user_id)
+        self.assertIsNone(self.get_success(d))
+        self._is_erased(user_id, True)
+
 
 class UserMembershipRestTestCase(unittest.HomeserverTestCase):