From 2c3548d9d89ed4c8cefd5b18d1b86ff0fc2f52bf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 4 Mar 2019 10:05:39 +0000 Subject: Update test_typing to use HomeserverTestCase. (#4771) --- tests/handlers/test_typing.py | 290 +++++++++++++++++++----------------------- 1 file changed, 133 insertions(+), 157 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 36e136cded..13486930fb 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,13 +24,17 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.utils import register_federation_servlets -from ..utils import ( - DeferredMockCallable, - MockClock, - MockHttpResource, - setup_test_homeserver, -) +# Some local users to test with +U_APPLE = UserID.from_string("@apple:test") +U_BANANA = UserID.from_string("@banana:test") + +# Remote user +U_ONION = UserID.from_string("@onion:farm") + +# Test room id +ROOM_ID = "a-room" def _expect_edu_transaction(edu_type, content, origin="test"): @@ -46,30 +50,21 @@ def _make_edu_transaction_json(edu_type, content): return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8') -class TypingNotificationsTestCase(unittest.TestCase): - """Tests typing notifications to rooms.""" - - @defer.inlineCallbacks - def setUp(self): - self.clock = MockClock() +class TypingNotificationsTestCase(unittest.HomeserverTestCase): + servlets = [register_federation_servlets] - self.mock_http_client = Mock(spec=[]) - self.mock_http_client.put_json = DeferredMockCallable() + def make_homeserver(self, reactor, clock): + # we mock out the keyring so as to skip the authentication check on the + # federation API call. + mock_keyring = Mock(spec=["verify_json_for_server"]) + mock_keyring.verify_json_for_server.return_value = defer.succeed(True) - self.mock_federation_resource = MockHttpResource() - - mock_notifier = Mock() - self.on_new_event = mock_notifier.on_new_event + # we mock out the federation client too + mock_federation_client = Mock(spec=["put_json"]) + mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) - self.auth = Mock(spec=[]) - self.state_handler = Mock() - - hs = yield setup_test_homeserver( - self.addCleanup, - "test", - auth=self.auth, - clock=self.clock, - datastore=Mock( + hs = self.setup_test_homeserver( + datastore=(Mock( spec=[ # Bits that Federation needs "prep_send_transaction", @@ -82,16 +77,21 @@ class TypingNotificationsTestCase(unittest.TestCase): "get_user_directory_stream_pos", "get_current_state_deltas", ] - ), - state_handler=self.state_handler, - handlers=Mock(), - notifier=mock_notifier, - resource_for_client=Mock(), - resource_for_federation=self.mock_federation_resource, - http_client=self.mock_http_client, - keyring=Mock(), + )), + notifier=Mock(), + http_client=mock_federation_client, + keyring=mock_keyring, ) + return hs + + def prepare(self, reactor, clock, hs): + # the tests assume that we are starting at unix time 1000 + reactor.pump((1000, )) + + mock_notifier = hs.get_notifier() + self.on_new_event = mock_notifier.on_new_event + self.handler = hs.get_typing_handler() self.event_source = hs.get_event_sources().sources["typing"] @@ -109,13 +109,12 @@ class TypingNotificationsTestCase(unittest.TestCase): self.datastore.get_received_txn_response = get_received_txn_response - self.room_id = "a-room" - self.room_members = [] def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): return set(member.domain for member in self.room_members) @@ -124,8 +123,7 @@ class TypingNotificationsTestCase(unittest.TestCase): def get_current_user_in_room(room_id): return set(str(u) for u in self.room_members) - - self.state_handler.get_current_user_in_room = get_current_user_in_room + hs.get_state_handler().get_current_user_in_room = get_current_user_in_room self.datastore.get_user_directory_stream_pos.return_value = ( # we deliberately return a non-None stream pos to avoid doing an initial_spam @@ -134,230 +132,208 @@ class TypingNotificationsTestCase(unittest.TestCase): self.datastore.get_current_state_deltas.return_value = None - self.auth.check_joined_room = check_joined_room - self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None - # Some local users to test with - self.u_apple = UserID.from_string("@apple:test") - self.u_banana = UserID.from_string("@banana:test") - - # Remote user - self.u_onion = UserID.from_string("@onion:farm") - - @defer.inlineCallbacks def test_started_typing_local(self): - self.room_members = [self.u_apple, self.u_banana] + self.room_members = [U_APPLE, U_BANANA] self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=20000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) - @defer.inlineCallbacks def test_started_typing_remote_send(self): - self.room_members = [self.u_apple, self.u_onion] - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call( - "farm", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu_transaction( - "m.typing", - content={ - "room_id": self.room_id, - "user_id": self.u_apple.to_string(), - "typing": True, - }, - ), - json_data_callback=ANY, - long_retries=True, - backoff_on_404=True, - ), - defer.succeed((200, "OK")), - ) + self.room_members = [U_APPLE, U_ONION] - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=20000, - ) + )) - yield put_json.await_calls() + put_json = self.hs.get_http_client().put_json + put_json.assert_called_once_with( + "farm", + path="/_matrix/federation/v1/send/1000000/", + data=_expect_edu_transaction( + "m.typing", + content={ + "room_id": ROOM_ID, + "user_id": U_APPLE.to_string(), + "typing": True, + }, + ), + json_data_callback=ANY, + long_retries=True, + backoff_on_404=True, + ) - @defer.inlineCallbacks def test_started_typing_remote_recv(self): - self.room_members = [self.u_apple, self.u_onion] + self.room_members = [U_APPLE, U_ONION] self.assertEquals(self.event_source.get_current_key(), 0) - (code, response) = yield self.mock_federation_resource.trigger( + (request, channel) = self.make_request( "PUT", "/_matrix/federation/v1/send/1000000/", _make_edu_transaction_json( "m.typing", content={ - "room_id": self.room_id, - "user_id": self.u_onion.to_string(), + "room_id": ROOM_ID, + "user_id": U_ONION.to_string(), "typing": True, }, ), federation_auth_origin=b'farm', ) + self.render(request) + self.assertEqual(channel.code, 200) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_onion.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_ONION.to_string()]}, } ], ) - @defer.inlineCallbacks def test_stopped_typing(self): - self.room_members = [self.u_apple, self.u_banana, self.u_onion] - - put_json = self.mock_http_client.put_json - put_json.expect_call_and_return( - call( - "farm", - path="/_matrix/federation/v1/send/1000000/", - data=_expect_edu_transaction( - "m.typing", - content={ - "room_id": self.room_id, - "user_id": self.u_apple.to_string(), - "typing": False, - }, - ), - json_data_callback=ANY, - long_retries=True, - backoff_on_404=True, - ), - defer.succeed((200, "OK")), - ) + self.room_members = [U_APPLE, U_BANANA, U_ONION] # Gut-wrenching from synapse.handlers.typing import RoomMember - member = RoomMember(self.room_id, self.u_apple.to_string()) + member = RoomMember(ROOM_ID, U_APPLE.to_string()) self.handler._member_typing_until[member] = 1002000 - self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) + self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()]) self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.stopped_typing( - target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id - ) + self.successResultOf(self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) - yield put_json.await_calls() + put_json = self.hs.get_http_client().put_json + put_json.assert_called_once_with( + "farm", + path="/_matrix/federation/v1/send/1000000/", + data=_expect_edu_transaction( + "m.typing", + content={ + "room_id": ROOM_ID, + "user_id": U_APPLE.to_string(), + "typing": False, + }, + ), + json_data_callback=ANY, + long_retries=True, + backoff_on_404=True, + ) self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, + "room_id": ROOM_ID, "content": {"user_ids": []}, } ], ) - @defer.inlineCallbacks def test_typing_timeout(self): - self.room_members = [self.u_apple, self.u_banana] + self.room_members = [U_APPLE, U_BANANA] self.assertEquals(self.event_source.get_current_key(), 0) - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=10000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[self.room_id])] + [call('typing_key', 1, rooms=[ROOM_ID])] ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) - self.clock.advance_time(16) + self.reactor.pump([16, ]) self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[self.room_id])] + [call('typing_key', 2, rooms=[ROOM_ID])] ) self.assertEquals(self.event_source.get_current_key(), 2) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=1 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=1 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, + "room_id": ROOM_ID, "content": {"user_ids": []}, } ], @@ -365,29 +341,29 @@ class TypingNotificationsTestCase(unittest.TestCase): # SYN-230 - see if we can still set after timeout - yield self.handler.started_typing( - target_user=self.u_apple, - auth_user=self.u_apple, - room_id=self.room_id, + self.successResultOf(self.handler.started_typing( + target_user=U_APPLE, + auth_user=U_APPLE, + room_id=ROOM_ID, timeout=10000, - ) + )) self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[self.room_id])] + [call('typing_key', 3, rooms=[ROOM_ID])] ) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = yield self.event_source.get_new_events( - room_ids=[self.room_id], from_key=0 + events = self.event_source.get_new_events( + room_ids=[ROOM_ID], from_key=0 ) self.assertEquals( events[0], [ { "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.u_apple.to_string()]}, + "room_id": ROOM_ID, + "content": {"user_ids": [U_APPLE.to_string()]}, } ], ) -- cgit 1.5.1 From a4c3a361b70bc02d65104240bef1b3cbb110bf22 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Mar 2019 14:25:33 +0000 Subject: Add rate-limiting on registration (#4735) * Rate-limiting for registration * Add unit test for registration rate limiting * Add config parameters for rate limiting on auth endpoints * Doc * Fix doc of rate limiting function Co-Authored-By: babolivier * Incorporate review * Fix config parsing * Fix linting errors * Set default config for auth rate limiting * Fix tests * Add changelog * Advance reactor instead of mocked clock * Move parameters to registration specific config and give them more sensible default values * Remove unused config options * Don't mock the rate limiter un MAU tests * Rename _register_with_store into register_with_store * Make CI happy * Remove unused import * Update sample config * Fix ratelimiting test for py2 * Add non-guest test --- changelog.d/4735.feature | 1 + docs/sample_config.yaml | 11 +++++++ synapse/api/ratelimiting.py | 31 ++++++++++--------- synapse/config/registration.py | 18 +++++++++++ synapse/handlers/_base.py | 4 +-- synapse/handlers/register.py | 39 ++++++++++++++++++----- synapse/replication/http/register.py | 8 +++-- synapse/rest/client/v2_alpha/register.py | 33 +++++++++++++++++--- tests/api/test_ratelimiting.py | 20 ++++++------ tests/handlers/test_profile.py | 4 +-- tests/replication/slave/storage/_base.py | 4 +-- tests/rest/client/v1/test_events.py | 4 +-- tests/rest/client/v1/test_rooms.py | 6 ++-- tests/rest/client/v1/test_typing.py | 4 +-- tests/rest/client/v2_alpha/test_register.py | 48 +++++++++++++++++++++++++++++ tests/test_mau.py | 3 +- tests/utils.py | 2 ++ 17 files changed, 186 insertions(+), 54 deletions(-) create mode 100644 changelog.d/4735.feature (limited to 'tests/handlers') diff --git a/changelog.d/4735.feature b/changelog.d/4735.feature new file mode 100644 index 0000000000..a4c0b196f6 --- /dev/null +++ b/changelog.d/4735.feature @@ -0,0 +1 @@ +Add configurable rate limiting to the /register endpoint. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7cf58d2182..e0140003fd 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -657,6 +657,17 @@ trusted_third_party_id_servers: # autocreate_auto_join_rooms: true +# Number of registration requests a client can send per second. +# Defaults to 1/minute (0.17). +# +#rc_registration_requests_per_second: 0.17 + +# Number of registration requests a client can send before being +# throttled. +# Defaults to 3. +# +#rc_registration_request_burst_count: 3.0 + ## Metrics ### diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3bb5b3da37..ad68079eeb 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,12 +23,13 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): - """Can the user send a message? + def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + """Can the entity (e.g. user or IP address) perform the action? Args: - user_id: The user sending a message. + key: The key we should use when rate limiting. Can be a user ID + (when sending events), an IP address, etc. time_now_s: The time now. - msg_rate_hz: The long term number of messages a user can send in a + rate_hz: The long term number of messages a user can send in a second. burst_count: How many messages the user can send before being limited. @@ -41,10 +42,10 @@ class Ratelimiter(object): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - user_id, (0., time_now_s, None), + key, (0., time_now_s, None), ) time_delta = time_now_s - time_start - sent_count = message_count - time_delta * msg_rate_hz + sent_count = message_count - time_delta * rate_hz if sent_count < 0: allowed = True time_start = time_now_s @@ -56,13 +57,13 @@ class Ratelimiter(object): message_count += 1 if update: - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz + self.message_counts[key] = ( + message_count, time_start, rate_hz ) - if msg_rate_hz > 0: + if rate_hz > 0: time_allowed = ( - time_start + (message_count - burst_count + 1) / msg_rate_hz + time_start + (message_count - burst_count + 1) / rate_hz ) if time_allowed < time_now_s: time_allowed = time_now_s @@ -72,12 +73,12 @@ class Ratelimiter(object): return allowed, time_allowed def prune_message_counts(self, time_now_s): - for user_id in list(self.message_counts.keys()): - message_count, time_start, msg_rate_hz = ( - self.message_counts[user_id] + for key in list(self.message_counts.keys()): + message_count, time_start, rate_hz = ( + self.message_counts[key] ) time_delta = time_now_s - time_start - if message_count - time_delta * msg_rate_hz > 0: + if message_count - time_delta * rate_hz > 0: break else: - del self.message_counts[user_id] + del self.message_counts[key] diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 2881482f96..d32f6fff73 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -54,6 +54,13 @@ class RegistrationConfig(Config): config.get("disable_msisdn_registration", False) ) + self.rc_registration_requests_per_second = config.get( + "rc_registration_requests_per_second", 0.17, + ) + self.rc_registration_request_burst_count = config.get( + "rc_registration_request_burst_count", 3, + ) + def default_config(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -140,6 +147,17 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # autocreate_auto_join_rooms: true + + # Number of registration requests a client can send per second. + # Defaults to 1/minute (0.17). + # + #rc_registration_requests_per_second: 0.17 + + # Number of registration requests a client can send before being + # throttled. + # Defaults to 3. + # + #rc_registration_request_burst_count: 3.0 """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 594754cfd8..d8d86d6ff3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -93,9 +93,9 @@ class BaseHandler(object): messages_per_second = self.hs.config.rc_messages_per_second burst_count = self.hs.config.rc_message_burst_count - allowed, time_allowed = self.ratelimiter.send_message( + allowed, time_allowed = self.ratelimiter.can_do_action( user_id, time_now, - msg_rate_hz=messages_per_second, + rate_hz=messages_per_second, burst_count=burst_count, update=update, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c0e06929bd..47d5e276f8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( AuthError, Codes, InvalidCaptchaError, + LimitExceededError, RegistrationError, SynapseError, ) @@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self.identity_handler = self.hs.get_handlers().identity_handler + self.ratelimiter = hs.get_ratelimiter() self._next_generated_user_id = None @@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler): threepid=None, user_type=None, default_display_name=None, + address=None, ): """Registers a new client on the server. @@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler): api.constants.UserTypes, or None for a normal user. default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. + address (str|None): the IP address used to perform the regitration. Returns: A tuple of (user_id, access_token). Raises: @@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: token = self.macaroon_gen.generate_access_token(user_id) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=default_display_name, admin=admin, user_type=user_type, + address=address, ) if self.hs.config.user_directory_search_all_users: @@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler): if default_display_name is None: default_display_name = localpart try: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, make_guest=make_guest, create_profile_with_displayname=default_display_name, + address=address, ) except SynapseError: # if user id is taken, just generate another @@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, password_hash="", appservice_id=service_id, @@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler): token = self.macaroon_gen.generate_access_token(user_id) if need_register: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def _register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None): + def register_with_store(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_displayname=None, admin=False, + user_type=None, address=None): """Register user in the datastore. Args: @@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. Returns: Deferred """ + # Don't rate limit for app services + if appservice_id is None and address is not None: + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + address, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + if self.hs.config.worker_app: return self._register_client( user_id=user_id, @@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + address=address, ) else: return self.store.register( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 1d27c9221f..912a5ac341 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs): super(ReplicationRegisterServlet, self).__init__(hs) self.store = hs.get_datastore() + self.registration_handler = hs.get_registration_handler() @staticmethod def _serialize_payload( user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, + create_profile_with_displayname, admin, user_type, address, ): """ Args: @@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. """ return { "token": token, @@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "create_profile_with_displayname": create_profile_with_displayname, "admin": admin, "user_type": user_type, + "address": address, } @defer.inlineCallbacks def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.store.register( + yield self.registration_handler.register_with_store( user_id=user_id, token=content["token"], password_hash=content["password_hash"], @@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], + address=content["address"] ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 94cbba4303..b7f354570c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -25,7 +25,12 @@ from twisted.internet import defer import synapse import synapse.types from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.api.errors import ( + Codes, + LimitExceededError, + SynapseError, + UnrecognizedRequestError, +) from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, @@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() + self.ratelimiter = hs.get_ratelimiter() + self.clock = hs.get_clock() @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) + client_addr = request.getClientIP() + + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + client_addr, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + update=False, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + kind = b"user" if b"kind" in request.args: kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body) + ret = yield self._do_guest_registration(body, address=client_addr) defer.returnValue(ret) return elif kind != b"user": @@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet): guest_access_token=guest_access_token, generate_token=False, threepid=threepid, + address=client_addr, ) # Necessary due to auth checks prior to the threepid being # written to the db @@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet): defer.returnValue(result) @defer.inlineCallbacks - def _do_guest_registration(self, params): + def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, - make_guest=True + make_guest=True, + address=address, ) # we don't allow guests to specify their own device_id, because diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 8933fe3b72..30a255d441 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,34 +6,34 @@ from tests import unittest class TestRatelimiter(unittest.TestCase): def test_allowed(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(20., time_allowed) def test_pruning(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertIn("test_id_1", limiter.message_counts) - allowed, time_allowed = limiter.send_message( - user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertNotIn("test_id_1", limiter.message_counts) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 80da1c8954..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 9e9fbbfe93..524af4f8d1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( "blue", federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - hs.get_ratelimiter().send_message.return_value = (True, 0) + hs.get_ratelimiter().can_do_action.return_value = (True, 0) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 483bebc832..36d8547275 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config.auto_join_rooms = [] hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["send_message"]) + config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a824be9a62..015c144248 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) @@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_send_message(self): + def test_can_do_action(self): msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0ad814c5e5..30fb77bac8 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.event_source = hs.get_event_sources().sources["typing"] self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 906b348d3e..3600434858 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + def test_POST_ratelimiting_guest(self): + self.hs.config.rc_registration_request_burst_count = 5 + + for i in range(0, 6): + url = self.url + b"?kind=guest" + request, channel = self.make_request(b"POST", url, b"{}") + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_POST_ratelimiting(self): + self.hs.config.rc_registration_request_burst_count = 5 + + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) diff --git a/tests/test_mau.py b/tests/test_mau.py index 04f95c942f..00be1a8c21 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index ee272157aa..e4c42f9fa8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -150,6 +150,8 @@ def default_config(name): config.admin_contact = None config.rc_messages_per_second = 10000 config.rc_message_burst_count = 10000 + config.rc_registration_request_burst_count = 3.0 + config.rc_registration_requests_per_second = 0.17 config.saml2_enabled = False config.public_baseurl = None config.default_identity_server = None -- cgit 1.5.1