diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,14 +23,12 @@ from urllib import parse as urlparse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
@@ -677,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(3):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join as many rooms as the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ # Let some time for the rate-limiter to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that all the rooms have been sent a profile update into.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(4):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -821,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
+ first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
@@ -828,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
+ second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@@ -837,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -852,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
- token=second_token,
+ token=second_token_str,
delete_local_events=True,
)
)
@@ -862,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -876,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
|