diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 95475bb651..d4e7fa1293 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -126,7 +139,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
events.append(self.get_success(store.get_event(valid_event_id)))
- # Advance the time by anothe 2 days. After this, the first event should be
+ # Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
@@ -140,11 +153,33 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
- message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ message_handler.get_room_data(
+ self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ )
)
# Send a first event to the room. This is the event we'll want to be purged at the
@@ -154,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -172,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index c6d8f24fe9..3ddcca288b 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -396,7 +390,7 @@ class CASTestCase(unittest.HomeserverTestCase):
</cas:serviceResponse>
"""
% cas_user_id
- )
+ ).encode("utf-8")
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -512,19 +506,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
+ jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = "HS256"
+ self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, "HS256").decode("ascii")
+ return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -542,35 +539,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "JWT expired")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Signature has expired"
+ )
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: The token is not yet valid (nbf)",
+ )
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "issuer": "test-issuer",
+ }
+ }
+ )
+ def test_login_iss(self):
+ """Test validating the issuer claim."""
+ # A valid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ )
+
+ # Not providing an issuer.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "iss" claim',
+ )
+
+ def test_login_iss_no_config(self):
+ """Test providing an issuer claim without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "audiences": ["test-audience"],
+ }
+ }
+ )
+ def test_login_aud(self):
+ """Test validating the audience claim."""
+ # A valid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
+ # Not providing an audience.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "aud" claim',
+ )
+
+ def test_login_aud_no_config(self):
+ """Test providing an audience without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
def test_login_no_token(self):
- params = json.dumps({"type": "m.login.jwt"})
+ params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -638,7 +726,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -650,6 +740,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0fdff79aa7..3c66255dac 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def test_put_presence_disabled(self):
"""
- PUT to the status endpoint with use_presence disbled will NOT call
+ PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler.
"""
self.hs.config.use_presence = False
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4886bbb401..c6c6edeac2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,18 +19,16 @@
"""Tests REST events for /rooms paths."""
import json
+from urllib import parse as urlparse
-from mock import Mock
-from six.moves.urllib import parse as urlparse
-
-from twisted.internet import defer
+from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
@@ -677,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(5):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join more rooms than the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ self.reactor.advance(1)
+ room_ids = room_ids + [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that all the rooms have been sent a profile update into.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(6):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -1976,3 +2059,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ShadowBannedTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..e66c9a4c4c 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator=None, is_public=True, tok=None):
+ def create_room_as(
+ self, room_creator=None, is_public=True, tok=None, expect_code=200,
+ ):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
- assert channel.result["code"] == b"200", channel.result
+ assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
- return channel.json_body["room_id"]
+
+ if expect_code == 200:
+ return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -88,7 +92,28 @@ class RestHelper(object):
expect_code=expect_code,
)
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ def change_membership(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ membership: str,
+ extra_data: dict = {},
+ tok: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> None:
+ """
+ Send a membership state event into a room.
+
+ Args:
+ room: The ID of the room to send to
+ src: The mxid of the event sender
+ targ: The mxid of the event's target. The state key
+ membership: The type of membership event
+ extra_data: Extra information to include in the content of the event
+ tok: The user access token to use
+ expect_code: The expected HTTP response code
+ """
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -97,6 +122,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
data = {"membership": membership}
+ data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 3ab611f618..152a5182fa 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -108,6 +108,46 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ def test_basic_password_reset_canonicalise_email(self):
+ """Test basic password reset flow
+ Request password reset with different spelling
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email_profile = "test@example.com"
+ email_passwort_reset = "TEST@EXAMPLE.COM"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email_profile,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ client_secret = "foobar"
+ session_id = self._request_token(email_passwort_reset, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ # Assert we can log in with the new password
+ self.login("kermit", new_password)
+
+ # Assert we can't log in with the old password
+ self.attempt_wrong_password_login("kermit", old_password)
+
def test_cant_reset_password_without_clicking_link(self):
"""Test that we do actually need to click the link in the email
"""
@@ -386,44 +426,67 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.email = "test@example.com"
self.url_3pid = b"account/3pid"
- def test_add_email(self):
- """Test adding an email to profile
- """
- client_secret = "foobar"
- session_id = self._request_token(self.email, client_secret)
+ def test_add_valid_email(self):
+ self.get_success(self._add_email(self.email, self.email))
- self.assertEquals(len(self.email_attempts), 1)
- link = self._get_link_from_email()
+ def test_add_valid_email_second_time(self):
+ self.get_success(self._add_email(self.email, self.email))
+ self.get_success(
+ self._request_token_invalid_email(
+ self.email,
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
+ )
+ )
- self._validate_token(link)
+ def test_add_valid_email_second_time_canonicalise(self):
+ self.get_success(self._add_email(self.email, self.email))
+ self.get_success(
+ self._request_token_invalid_email(
+ "TEST@EXAMPLE.COM",
+ expected_errcode=Codes.THREEPID_IN_USE,
+ expected_error="Email is already in use",
+ )
+ )
- request, channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
+ def test_add_email_no_at(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "address-without-at.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ def test_add_email_two_at(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "foo@foo@test.bar",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
+ )
- # Get user
- request, channel = self.make_request(
- "GET", self.url_3pid, access_token=self.user_id_tok,
+ def test_add_email_bad_format(self):
+ self.get_success(
+ self._request_token_invalid_email(
+ "user@bad.example.net@good.example.com",
+ expected_errcode=Codes.UNKNOWN,
+ expected_error="Unable to parse email address",
+ )
)
- self.render(request)
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
+ def test_add_email_domain_to_lower(self):
+ self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+
+ def test_add_email_domain_with_umlaut(self):
+ self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+
+ def test_add_email_address_casefold(self):
+ self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+
+ def test_address_trim(self):
+ self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
@@ -616,6 +679,19 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"]
+ def _request_token_invalid_email(
+ self, email, expected_errcode, expected_error, client_secret="foobar",
+ ):
+ request, channel = self.make_request(
+ "POST",
+ b"account/3pid/email/requestToken",
+ {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ )
+ self.render(request)
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(expected_errcode, channel.json_body["errcode"])
+ self.assertEqual(expected_error, channel.json_body["error"])
+
def _validate_token(self, link):
# Remove the host
path = link.replace("https://example.com", "")
@@ -643,3 +719,42 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
assert match, "Could not find link in email"
return match.group(0)
+
+ def _add_email(self, request_email, expected_email):
+ """Test adding an email to profile
+ """
+ client_secret = "foobar"
+ session_id = self._request_token(request_email, client_secret)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ request, channel = self.make_request(
+ "POST",
+ b"/_matrix/client/unstable/account/3pid/add",
+ {
+ "client_secret": client_secret,
+ "sid": session_id,
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user_id,
+ "password": "test",
+ },
+ },
+ access_token=self.user_id_tok,
+ )
+
+ self.render(request)
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Get user
+ request, channel = self.make_request(
+ "GET", self.url_3pid, access_token=self.user_id_tok,
+ )
+ self.render(request)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
+ self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7e3dc22f64..0f33c7806d 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index c7e5859970..99c9f4e928 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -15,8 +15,7 @@
import itertools
import json
-
-import six
+import urllib
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
@@ -100,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
def test_basic_paginate_relations(self):
- """Tests that calling pagination API corectly the latest relations.
+ """Tests that calling pagination API correctly the latest relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
self.assertEquals(200, channel.code, channel.json_body)
@@ -134,7 +133,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Make sure next_batch has something in it that looks like it could be a
# valid token.
self.assertIsInstance(
- channel.json_body.get("next_batch"), six.string_types, channel.json_body
+ channel.json_body.get("next_batch"), str, channel.json_body
)
def test_repeated_paginate_relations(self):
@@ -278,7 +277,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
prev_token = None
found_event_ids = []
- encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
+ encoded_key = urllib.parse.quote_plus("👍".encode("utf-8"))
for _ in range(20):
from_token = ""
if prev_token:
@@ -670,7 +669,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
query = ""
if key:
- query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
+ query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8"))
original_id = parent_id if parent_id else self.parent_id
|