summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_catch_up.py49
-rw-r--r--tests/handlers/test_oidc.py132
-rw-r--r--tests/handlers/test_presence.py20
-rw-r--r--tests/http/test_proxyagent.py40
-rw-r--r--tests/rest/admin/test_user.py173
-rw-r--r--tests/rest/client/test_third_party_rules.py62
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py36
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py62
-rw-r--r--tests/server.py2
-rw-r--r--tests/unittest.py12
10 files changed, 527 insertions, 61 deletions
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
 
 from mock import Mock
 
+from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.federation.sender import PerDestinationQueue, TransactionManager
 from synapse.federation.units import Edu
@@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.assertNotIn("zzzerver", woken)
         # - all destinations are woken exactly once; they appear once in woken.
         self.assertCountEqual(woken, server_names[:-1])
+
+    @override_config({"send_federation": True})
+    def test_not_latest_event(self):
+        """Test that we send the latest event in the room even if its not ours."""
+
+        per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+        # Make a room with a local user, and two servers. One will go offline
+        # and one will send some events.
+        self.register_user("u1", "you the one")
+        u1_token = self.login("u1", "you the one")
+        room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+        )
+        event_1 = self.get_success(
+            event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+        )
+
+        # First we send something from the local server, so that we notice the
+        # remote is down and go into catchup mode.
+        self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+        # Now simulate us receiving an event from the still online remote.
+        event_2 = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                type=EventTypes.Message,
+                sender="@user:host3",
+                room_id=room_1,
+                content={"msgtype": "m.text", "body": "Hello"},
+            )
+        )
+
+        self.get_success(
+            self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+                "host2", event_1.internal_metadata.stream_ordering
+            )
+        )
+
+        self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+        # We expect only the last message from the remote, event_2, to have been
+        # sent, rather than the last *local* event that was sent.
+        self.assertEqual(len(sent_pdus), 1)
+        self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+        self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5e9c9c2e88..c7796fb837 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
+                "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+            }
+        }
+    )
+    def test_attribute_requirements(self):
+        """The required attributes must be met from the OIDC userinfo response."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # userinfo lacking "test": "foobar" attribute should fail.
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": "foobar" attribute should succeed.
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": "foobar",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+        )
+
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
+                "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+            }
+        }
+    )
+    def test_attribute_requirements_contains(self):
+        """Test that auth succeeds if userinfo attribute CONTAINS required value"""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": ["foobar", "foo", "bar"],
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@tester:test", "oidc", ANY, ANY, None, new_user=True
+        )
+
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
+                "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+            }
+        }
+    )
+    def test_attribute_requirements_mismatch(self):
+        """
+        Test that auth fails if attributes exist but don't match,
+        or are non-string values.
+        """
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        # userinfo with "test": "not_foobar" attribute should fail
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": "not_foobar",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": ["foo", "bar"] attribute should fail
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": ["foo", "bar"],
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": False attribute should fail
+        # this is largely just to ensure we don't crash here
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": False,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": None attribute should fail
+        # a value of None breaks the OIDC spec, but it's important to not crash here
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": None,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": 1 attribute should fail
+        # this is largely just to ensure we don't crash here
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": 1,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
+        # userinfo with "test": 3.14 attribute should fail
+        # this is largely just to ensure we don't crash here
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+            "test": 3.14,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+
     def _generate_oidc_session_token(
         self,
         state: str,
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 996c614198..77330f59a9 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase):
         self.assertIsNotNone(new_state)
         self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
 
+    def test_busy_no_idle(self):
+        """
+        Tests that a user setting their presence to busy but idling doesn't turn their
+        presence state into unavailable.
+        """
+        user_id = "@foo:bar"
+        now = 5000000
+
+        state = UserPresenceState.default(user_id)
+        state = state.copy_and_replace(
+            state=PresenceState.BUSY,
+            last_active_ts=now - IDLE_TIMER - 1,
+            last_user_sync_ts=now,
+        )
+
+        new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+
+        self.assertIsNotNone(new_state)
+        self.assertEquals(new_state.state, PresenceState.BUSY)
+
     def test_sync_timeout(self):
         user_id = "@foo:bar"
         now = 5000000
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 505ffcd300..3ea8b5bec7 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,8 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 import os
+from typing import Optional
 from unittest.mock import patch
 
 import treq
@@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase):
 
     @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
     def test_https_request_via_proxy(self):
+        """Tests that TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials=None)
+
+    @patch.dict(
+        os.environ,
+        {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+    )
+    def test_https_request_via_proxy_with_auth(self):
+        """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+        self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+    def _do_https_request_via_proxy(
+        self,
+        auth_credentials: Optional[str] = None,
+    ):
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
@@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(request.method, b"CONNECT")
         self.assertEqual(request.path, b"test.com:443")
 
+        # Check whether auth credentials have been supplied to the proxy
+        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+            b"Proxy-Authorization"
+        )
+
+        if auth_credentials is not None:
+            # Compute the correct header value for Proxy-Authorization
+            encoded_credentials = base64.b64encode(b"bob:pinkponies")
+            expected_header_value = b"Basic " + encoded_credentials
+
+            # Validate the header's value
+            self.assertIn(expected_header_value, proxy_auth_header_values)
+        else:
+            # Check that the Proxy-Authorization header has not been supplied to the proxy
+            self.assertIsNone(proxy_auth_header_values)
+
         # tell the proxy server not to close the connection
         proxy_server.persistent = True
 
@@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase):
         self.assertEqual(request.method, b"GET")
         self.assertEqual(request.path, b"/abc")
         self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+        # Check that the destination server DID NOT receive proxy credentials
+        proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+            b"Proxy-Authorization"
+        )
+        self.assertIsNone(proxy_auth_header_values)
+
         request.write(b"result")
         request.finish()
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..cf61f284cb 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.auth_handler = hs.get_auth_handler()
 
+        # create users and get access tokens
+        # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
+        self.admin_user_tok = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.admin_user, device_id=None, valid_until_ms=None
+            )
+        )
 
         self.other_user = self.register_user("user", "pass", displayname="User")
-        self.other_user_token = self.login("user", "pass")
+        self.other_user_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.other_user, device_id=None, valid_until_ms=None
+            )
+        )
         self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
             self.other_user
         )
@@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
@@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
-        self.assertEqual(False, channel.json_body["shadow_banned"])
+        self.assertFalse(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["shadow_banned"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
@@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Admin user is not blocked by mau anymore
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {
@@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertTrue(profile["display_name"] == "User")
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
         # Set new displayname user
-        body = json.dumps({"displayname": "Foobar"})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"displayname": "Foobar"},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
         self.assertEqual("Foobar", channel.json_body["displayname"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
     def test_reactivate_user(self):
         """
@@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Attempt to reactivate the user (without a password).
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNotNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
-        d = self.store.mark_user_erased("@user:test")
-        self.assertIsNone(self.get_success(d))
-        self._is_erased("@user:test", True)
 
-        # Attempt to reactivate the user (without a password).
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_reactivate_user_localdb_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        # Reactivate the user.
+        # Reactivate the user without a password.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
-                encoding="utf_8"
-            ),
+            content={"deactivated": False},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased("@user:test", False)
 
-        # Get user
+    @override_config({"password_config": {"enabled": False}})
+    def test_reactivate_user_password_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
-            "GET",
+            "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
+            content={"deactivated": False, "password": "foo"},
         )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
+        # Reactivate the user without a password.
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
 
     def test_set_user_as_admin(self):
@@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Set a user as an admin
-        body = json.dumps({"admin": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"admin": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
         # Get user
         channel = self.make_request(
@@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
     def test_accidental_deactivation_prevention(self):
         """
@@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v2/users/@bob:test"
 
         # Create user
-        body = json.dumps({"password": "abc123"})
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123"},
         )
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, channel.json_body["deactivated"])
 
         # Change password (and use a str for deactivate instead of a bool)
-        body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123", "deactivated": "false"},
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])
 
-    def _is_erased(self, user_id, expect):
+    def _is_erased(self, user_id: str, expect: bool) -> None:
         """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
         if expect:
@@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         else:
             self.assertFalse(self.get_success(d))
 
+    def _deactivate_user(self, user_id: str) -> None:
+        """Deactivate user and set as erased"""
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+            access_token=self.admin_user_tok,
+            content={"deactivated": True},
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased(user_id, False)
+        d = self.store.mark_user_erased(user_id)
+        self.assertIsNone(self.get_success(d))
+        self._is_erased(user_id, True)
+
 
 class UserMembershipRestTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 227fffab58..bf39014277 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
         ev = channel.json_body
         self.assertEqual(ev["content"]["x"], "y")
 
+    def test_message_edit(self):
+        """Ensure that the module doesn't cause issues with edited messages."""
+        # first patch the event checker so that it will modify the event
+        async def check(ev: EventBase, state):
+            d = ev.get_dict()
+            d["content"] = {
+                "msgtype": "m.text",
+                "body": d["content"]["body"].upper(),
+            }
+            return d
+
+        current_rules_module().check_event_allowed = check
+
+        # Send an event, then edit it.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+            {
+                "msgtype": "m.text",
+                "body": "Original body",
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        orig_event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id,
+            {
+                "m.new_content": {"msgtype": "m.text", "body": "Edited body"},
+                "m.relates_to": {
+                    "rel_type": "m.replace",
+                    "event_id": orig_event_id,
+                },
+                "msgtype": "m.text",
+                "body": "Edited body",
+            },
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        edited_event_id = channel.json_body["event_id"]
+
+        # ... and check that they both got modified
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.result["code"], b"200", channel.result)
+        ev = channel.json_body
+        self.assertEqual(ev["content"]["body"], "EDITED BODY")
+
     def test_send_event(self):
         """Tests that the module can send an event into a room via the module api"""
         content = {
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import capabilities
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver()
         self.store = hs.get_datastore()
         self.config = hs.config
+        self.auth_handler = hs.get_auth_handler()
         return hs
 
     def test_check_auth_required(self):
@@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             capabilities["m.room_versions"]["default"],
         )
 
-    def test_get_change_password_capabilities(self):
+    def test_get_change_password_capabilities_password_login(self):
         localpart = "user"
         password = "pass"
         user = self.register_user(localpart, password)
@@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
-
-        # Test case where password is handled outside of Synapse
         self.assertTrue(capabilities["m.change_password"]["enabled"])
-        self.get_success(self.store.user_set_password_hash(user, None))
+
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_get_change_password_capabilities_localdb_disabled(self):
+        localpart = "user"
+        password = "pass"
+        user = self.register_user(localpart, password)
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+    @override_config({"password_config": {"enabled": False}})
+    def test_get_change_password_capabilities_password_disabled(self):
+        localpart = "user"
+        password = "pass"
+        user = self.register_user(localpart, password)
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                user, device_id=None, valid_until_ms=None
+            )
+        )
+
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
 
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 7c457754f1..e7bb5583fc 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         # We need to enable msc1849 support for aggregations
         config = self.default_config()
         config["experimental_msc1849_support_enabled"] = True
+
+        # We enable frozen dicts as relations/edits change event contents, so we
+        # want to test that we don't modify the events in the caches.
+        config["use_frozen_dicts"] = True
+
         return self.setup_test_homeserver(config=config)
 
     def prepare(self, reactor, clock, hs):
@@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
+    def test_edit_reply(self):
+        """Test that editing a reply works."""
+
+        # Create a reply to edit.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "A reply!"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        reply = channel.json_body["event_id"]
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+            parent_id=reply,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, reply),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to see the new body in the dict, as well as the reference
+        # metadata sill intact.
+        self.assertDictContainsSubset(new_body, channel.json_body["content"])
+        self.assertDictContainsSubset(
+            {
+                "m.relates_to": {
+                    "event_id": self.parent_id,
+                    "key": None,
+                    "rel_type": "m.reference",
+                }
+            },
+            channel.json_body["content"],
+        )
+
+        # We expect that the edit relation appears in the unsigned relations
+        # section.
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+        m_replace_dict = relations_dict[RelationTypes.REPLACE]
+        for key in ["event_id", "sender", "origin_server_ts"]:
+            self.assertIn(key, m_replace_dict)
+
+        self.assert_dict(
+            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+        )
+
     def test_relations_redaction_redacts_edits(self):
         """Test that edits of an event are redacted when the original event
         is redacted.
diff --git a/tests/server.py b/tests/server.py
index 2287d20076..57cc4ac605 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -593,7 +593,7 @@ class FakeTransport:
         if self.disconnected:
             return
 
-        if getattr(self.other, "transport") is None:
+        if not hasattr(self.other, "transport"):
             # the other has no transport yet; reschedule
             if self.autoflush:
                 self._reactor.callLater(0.0, self.flush)
diff --git a/tests/unittest.py b/tests/unittest.py
index ca7031c724..58a4daa1ec 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 
+from synapse import events
 from synapse.api.constants import EventTypes, Membership
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
@@ -140,7 +141,7 @@ class TestCase(unittest.TestCase):
             try:
                 self.assertEquals(attrs[key], getattr(obj, key))
             except AssertionError as e:
-                raise (type(e))(e.message + " for '.%s'" % key)
+                raise (type(e))("Assert error for '.{}':".format(key)) from e
 
     def assert_dict(self, required, actual):
         """Does a partial assert of a dict.
@@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase):
         self._hs_args = {"clock": self.clock, "reactor": self.reactor}
         self.hs = self.make_homeserver(self.reactor, self.clock)
 
+        # Honour the `use_frozen_dicts` config option. We have to do this
+        # manually because this is taken care of in the app `start` code, which
+        # we don't run. Plus we want to reset it on tearDown.
+        events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+
         if self.hs is None:
             raise Exception("No homeserver returned from make_homeserver.")
 
@@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "prepare"):
             self.prepare(self.reactor, self.clock, self.hs)
 
+    def tearDown(self):
+        # Reset to not use frozen dicts.
+        events.USE_FROZEN_DICTS = False
+
     def wait_on_thread(self, deferred, timeout=10):
         """
         Wait until a Deferred is done, where it's waiting on a real thread.