diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
from mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
@@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertNotIn("zzzerver", woken)
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
+
+ @override_config({"send_federation": True})
+ def test_not_latest_event(self):
+ """Test that we send the latest event in the room even if its not ours."""
+
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ # Make a room with a local user, and two servers. One will go offline
+ # and one will send some events.
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ event_1 = self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+ )
+
+ # First we send something from the local server, so that we notice the
+ # remote is down and go into catchup mode.
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+ # Now simulate us receiving an event from the still online remote.
+ event_2 = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ type=EventTypes.Message,
+ sender="@user:host3",
+ room_id=room_1,
+ content={"msgtype": "m.text", "body": "Hello"},
+ )
+ )
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_1.internal_metadata.stream_ordering
+ )
+ )
+
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # We expect only the last message from the remote, event_2, to have been
+ # sent, rather than the last *local* event that was sent.
+ self.assertEqual(len(sent_pdus), 1)
+ self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+ self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5e9c9c2e88..c7796fb837 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the OIDC userinfo response."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # userinfo lacking "test": "foobar" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": "foobar" attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_contains(self):
+ """Test that auth succeeds if userinfo attribute CONTAINS required value"""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foobar", "foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_mismatch(self):
+ """
+ Test that auth fails if attributes exist but don't match,
+ or are non-string values.
+ """
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": "not_foobar" attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "not_foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": ["foo", "bar"] attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": False attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": False,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": None attribute should fail
+ # a value of None breaks the OIDC spec, but it's important to not crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 1 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 1,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 3.14 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 3.14,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
def _generate_oidc_session_token(
self,
state: str,
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 996c614198..77330f59a9 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
+ def test_busy_no_idle(self):
+ """
+ Tests that a user setting their presence to busy but idling doesn't turn their
+ presence state into unavailable.
+ """
+ user_id = "@foo:bar"
+ now = 5000000
+
+ state = UserPresenceState.default(user_id)
+ state = state.copy_and_replace(
+ state=PresenceState.BUSY,
+ last_active_ts=now - IDLE_TIMER - 1,
+ last_user_sync_ts=now,
+ )
+
+ new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+
+ self.assertIsNotNone(new_state)
+ self.assertEquals(new_state.state, PresenceState.BUSY)
+
def test_sync_timeout(self):
user_id = "@foo:bar"
now = 5000000
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
index 505ffcd300..3ea8b5bec7 100644
--- a/tests/http/test_proxyagent.py
+++ b/tests/http/test_proxyagent.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
import logging
import os
+from typing import Optional
from unittest.mock import patch
import treq
@@ -242,6 +244,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
+ """Tests that TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials=None)
+
+ @patch.dict(
+ os.environ,
+ {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
+ )
+ def test_https_request_via_proxy_with_auth(self):
+ """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
+ self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
+
+ def _do_https_request_via_proxy(
+ self,
+ auth_credentials: Optional[str] = None,
+ ):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
@@ -278,6 +295,22 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"CONNECT")
self.assertEqual(request.path, b"test.com:443")
+ # Check whether auth credentials have been supplied to the proxy
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+
+ if auth_credentials is not None:
+ # Compute the correct header value for Proxy-Authorization
+ encoded_credentials = base64.b64encode(b"bob:pinkponies")
+ expected_header_value = b"Basic " + encoded_credentials
+
+ # Validate the header's value
+ self.assertIn(expected_header_value, proxy_auth_header_values)
+ else:
+ # Check that the Proxy-Authorization header has not been supplied to the proxy
+ self.assertIsNone(proxy_auth_header_values)
+
# tell the proxy server not to close the connection
proxy_server.persistent = True
@@ -312,6 +345,13 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+
+ # Check that the destination server DID NOT receive proxy credentials
+ proxy_auth_header_values = request.requestHeaders.getRawHeaders(
+ b"Proxy-Authorization"
+ )
+ self.assertIsNone(proxy_auth_header_values)
+
request.write(b"result")
request.finish()
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..cf61f284cb 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+ # create users and get access tokens
+ # regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
+ self.admin_user_tok = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.admin_user, device_id=None, valid_until_ms=None
+ )
+ )
self.other_user = self.register_user("user", "pass", displayname="User")
- self.other_user_token = self.login("user", "pass")
+ self.other_user_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.other_user, device_id=None, valid_until_ms=None
+ )
+ )
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
@@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
@@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual(False, channel.json_body["shadow_banned"])
+ self.assertFalse(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
@@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{
@@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertTrue(profile["display_name"] == "User")
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# Set new displayname user
- body = json.dumps({"displayname": "Foobar"})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"displayname": "Foobar"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
def test_reactivate_user(self):
"""
@@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Attempt to reactivate the user (without a password).
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
- d = self.store.mark_user_erased("@user:test")
- self.assertIsNone(self.get_success(d))
- self._is_erased("@user:test", True)
- # Attempt to reactivate the user (without a password).
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_reactivate_user_localdb_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- # Reactivate the user.
+ # Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False, "password": "foo"}).encode(
- encoding="utf_8"
- ),
+ content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased("@user:test", False)
- # Get user
+ @override_config({"password_config": {"enabled": False}})
+ def test_reactivate_user_password_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"deactivated": False, "password": "foo"},
)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+ # Reactivate the user without a password.
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
@@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Set a user as an admin
- body = json.dumps({"admin": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"admin": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
# Get user
channel = self.make_request(
@@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123"})
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123"},
)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool)
- body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "deactivated": "false"},
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
- def _is_erased(self, user_id, expect):
+ def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
@@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
else:
self.assertFalse(self.get_success(d))
+ def _deactivate_user(self, user_id: str) -> None:
+ """Deactivate user and set as erased"""
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ access_token=self.admin_user_tok,
+ content={"deactivated": True},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased(user_id, False)
+ d = self.store.mark_user_erased(user_id)
+ self.assertIsNone(self.get_success(d))
+ self._is_erased(user_id, True)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 227fffab58..bf39014277 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
+ def test_message_edit(self):
+ """Ensure that the module doesn't cause issues with edited messages."""
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ d = ev.get_dict()
+ d["content"] = {
+ "msgtype": "m.text",
+ "body": d["content"]["body"].upper(),
+ }
+ return d
+
+ current_rules_module().check_event_allowed = check
+
+ # Send an event, then edit it.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ orig_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id,
+ {
+ "m.new_content": {"msgtype": "m.text", "body": "Edited body"},
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": orig_event_id,
+ },
+ "msgtype": "m.text",
+ "body": "Edited body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ edited_event_id = channel.json_body["event_id"]
+
+ # ... and check that they both got modified
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "EDITED BODY")
+
def test_send_event(self):
"""Tests that the module can send an event into a room via the module api"""
content = {
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
from tests import unittest
+from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
+ self.auth_handler = hs.get_auth_handler()
return hs
def test_check_auth_required(self):
@@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"],
)
- def test_get_change_password_capabilities(self):
+ def test_get_change_password_capabilities_password_login(self):
localpart = "user"
password = "pass"
user = self.register_user(localpart, password)
@@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
-
- # Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.get_success(self.store.user_set_password_hash(user, None))
+
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_get_change_password_capabilities_localdb_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+ @override_config({"password_config": {"enabled": False}})
+ def test_get_change_password_capabilities_password_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 7c457754f1..e7bb5583fc 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# We need to enable msc1849 support for aggregations
config = self.default_config()
config["experimental_msc1849_support_enabled"] = True
+
+ # We enable frozen dicts as relations/edits change event contents, so we
+ # want to test that we don't modify the events in the caches.
+ config["use_frozen_dicts"] = True
+
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ def test_edit_reply(self):
+ """Test that editing a reply works."""
+
+ # Create a reply to edit.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A reply!"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=reply,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, reply),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to see the new body in the dict, as well as the reference
+ # metadata sill intact.
+ self.assertDictContainsSubset(new_body, channel.json_body["content"])
+ self.assertDictContainsSubset(
+ {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": None,
+ "rel_type": "m.reference",
+ }
+ },
+ channel.json_body["content"],
+ )
+
+ # We expect that the edit relation appears in the unsigned relations
+ # section.
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
def test_relations_redaction_redacts_edits(self):
"""Test that edits of an event are redacted when the original event
is redacted.
diff --git a/tests/unittest.py b/tests/unittest.py
index ca7031c724..58a4daa1ec 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
+from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
@@ -140,7 +141,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))(e.message + " for '.%s'" % key)
+ raise (type(e))("Assert error for '.{}':".format(key)) from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
+ # Honour the `use_frozen_dicts` config option. We have to do this
+ # manually because this is taken care of in the app `start` code, which
+ # we don't run. Plus we want to reset it on tearDown.
+ events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def tearDown(self):
+ # Reset to not use frozen dicts.
+ events.USE_FROZEN_DICTS = False
+
def wait_on_thread(self, deferred, timeout=10):
"""
Wait until a Deferred is done, where it's waiting on a real thread.
|