diff --git a/tests/config/test_server.py b/tests/config/test_server.py
index a10d017120..98af7aa675 100644
--- a/tests/config/test_server.py
+++ b/tests/config/test_server.py
@@ -15,7 +15,8 @@
import yaml
-from synapse.config.server import ServerConfig, is_threepid_reserved
+from synapse.config._base import ConfigError
+from synapse.config.server import ServerConfig, generate_ip_set, is_threepid_reserved
from tests import unittest
@@ -128,3 +129,61 @@ class ServerConfigTestCase(unittest.TestCase):
)
self.assertEqual(conf["listeners"], expected_listeners)
+
+
+class GenerateIpSetTestCase(unittest.TestCase):
+ def test_empty(self):
+ ip_set = generate_ip_set(())
+ self.assertFalse(ip_set)
+
+ ip_set = generate_ip_set((), ())
+ self.assertFalse(ip_set)
+
+ def test_generate(self):
+ """Check adding IPv4 and IPv6 addresses."""
+ # IPv4 address
+ ip_set = generate_ip_set(("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ # IPv4 CIDR
+ ip_set = generate_ip_set(("1.2.3.4/24",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ # IPv6 address
+ ip_set = generate_ip_set(("2001:db8::8a2e:370:7334",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+ # IPv6 CIDR
+ ip_set = generate_ip_set(("2001:db8::/104",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 1)
+
+ # The addresses can overlap OK.
+ ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ def test_extra(self):
+ """Extra IP addresses are treated the same."""
+ ip_set = generate_ip_set((), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ ip_set = generate_ip_set(("1.1.1.1",), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 8)
+
+ # They can duplicate without error.
+ ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
+ self.assertEqual(len(ip_set.iter_cidrs()), 4)
+
+ def test_bad_value(self):
+ """An error should be raised if a bad value is passed in."""
+ with self.assertRaises(ConfigError):
+ generate_ip_set(("not-an-ip",))
+
+ with self.assertRaises(ConfigError):
+ generate_ip_set(("1.2.3.4/128",))
+
+ with self.assertRaises(ConfigError):
+ generate_ip_set((":::",))
+
+ # The following get treated as empty data.
+ self.assertFalse(generate_ip_set(None))
+ self.assertFalse(generate_ip_set({}))
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index c37bb6440e..7baf224f7e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -62,7 +62,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None
+ "@test_user:test", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -85,7 +85,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None
+ "@test_user:test", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -94,7 +94,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None
+ "@test_user:test", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -112,7 +112,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
+ "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 0b24b89a2e..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(other_server, event, event.room_version,)
+ )
+
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(other_server, event, event.room_version,),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b3dfa40d25..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -40,7 +40,7 @@ ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -58,12 +58,6 @@ COMMON_CONFIG = {
}
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
class TestMappingProvider:
@staticmethod
def parse_config(config):
@@ -340,8 +334,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
# For some reason, call.args does not work with python3.5
args = calls[0][0]
kwargs = calls[0][1]
- self.assertEqual(args[0], COOKIE_NAME)
- self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ self.assertEqual(args[0], b"oidc_session")
+ self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
cookie = args[1]
macaroon = pymacaroons.Macaroon.deserialize(cookie)
@@ -419,7 +416,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None,
+ expected_user_id, request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -450,7 +447,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None,
+ expected_user_id, request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@@ -623,7 +620,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- "@foo:test", request, client_redirect_url, {"phone": "1234567"},
+ "@foo:test",
+ request,
+ client_redirect_url,
+ {"phone": "1234567"},
+ new_user=True,
)
def test_map_userinfo_to_user(self):
@@ -637,7 +638,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", ANY, ANY, None,
+ "@test_user:test", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -648,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test", ANY, ANY, None,
+ "@test_user_2:test", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -685,14 +686,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None,
+ user.to_string(), ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None,
+ user.to_string(), ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -707,7 +708,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None,
+ user.to_string(), ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -743,7 +744,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test", ANY, ANY, None,
+ "@TEST_USER_2:test", ANY, ANY, None, new_user=False
)
def test_map_userinfo_to_invalid_localpart(self):
@@ -779,7 +780,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", ANY, ANY, None,
+ "@test_user1:test", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 261c7083d1..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None
+ "@test_user:test", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None
+ "@test_user:test", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None
+ "@test_user:test", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", request, "", None
+ "@test_user1:test", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 961bf09de9..22f452ec24 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -124,13 +124,18 @@ class EmailPusherTests(HomeserverTestCase):
)
self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
- # The other user sends some messages
+ # The other user sends a single message.
self.helper.send(room, body="Hi!", tok=self.others[0].token)
- self.helper.send(room, body="There!", tok=self.others[0].token)
# We should get emailed about that message
self._check_for_mail()
+ # The other user sends multiple messages.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[0].token)
+
+ self._check_for_mail()
+
def test_invite_sends_email(self):
# Create a room and invite the user to it
room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
@@ -187,6 +192,75 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages
self._check_for_mail()
+ def test_multiple_rooms(self):
+ # We want to test multiple notifications from multiple rooms, so we pause
+ # processing of push while we send messages.
+ self.pusher._pause_processing()
+
+ # Create a simple room with multiple other users
+ rooms = [
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ self.helper.create_room_as(self.user_id, tok=self.access_token),
+ ]
+
+ for r, other in zip(rooms, self.others):
+ self.helper.invite(
+ room=r, src=self.user_id, tok=self.access_token, targ=other.id
+ )
+ self.helper.join(room=r, user=other.id, tok=other.token)
+
+ # The other users send some messages
+ self.helper.send(rooms[0], body="Hi!", tok=self.others[0].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+ self.helper.send(rooms[1], body="There!", tok=self.others[1].token)
+
+ # Nothing should have happened yet, as we're paused.
+ assert not self.email_attempts
+
+ self.pusher._resume_processing()
+
+ # We should get emailed about those messages
+ self._check_for_mail()
+
+ def test_empty_room(self):
+ """All users leaving a room shouldn't cause the pusher to break."""
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+
+ # Leave the room before the message is processed.
+ self.helper.leave(room, self.user_id, tok=self.access_token)
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
+ def test_empty_room_multiple_messages(self):
+ """All users leaving a room shouldn't cause the pusher to break."""
+ # Create a simple room with two users
+ room = self.helper.create_room_as(self.user_id, tok=self.access_token)
+ self.helper.invite(
+ room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
+ )
+ self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
+
+ # The other user sends a single message.
+ self.helper.send(room, body="Hi!", tok=self.others[0].token)
+ self.helper.send(room, body="There!", tok=self.others[0].token)
+
+ # Leave the room before the message is processed.
+ self.helper.leave(room, self.user_id, tok=self.access_token)
+ self.helper.leave(room, self.others[0].id, tok=self.others[0].token)
+
+ # We should get emailed about that message
+ self._check_for_mail()
+
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
@@ -239,3 +313,6 @@ class EmailPusherTests(HomeserverTestCase):
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+
+ # Reset the attempts.
+ self.email_attempts = []
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index f4f5569a89..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)
+ def test_room_state(self):
+ """Test that room state can be requested correctly"""
+ # Create two test rooms
+ room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
+ channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertIn("state", channel.json_body)
+ # testing that the state events match is painful and not done here. We assume that
+ # the create_room already does the right thing, so no need to verify that we got
+ # the state events it created.
+
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index e2bb945453..ceb4ad2366 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,7 +15,7 @@
import time
import urllib.parse
-from typing import Any, Dict, Union
+from typing import Any, Dict, List, Union
from urllib.parse import urlencode
from mock import Mock
@@ -29,8 +29,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
-from synapse.rest.synapse.client.pick_idp import PickIdpResource
-from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import create_requester
from tests import unittest
@@ -423,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_get_login_flows(self):
@@ -497,13 +493,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
+ html = channel.result["body"].decode("utf-8")
p = TestHtmlParser()
- p.feed(channel.result["body"].decode("utf-8"))
+ p.feed(html)
p.close()
- self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"])
+ # there should be a link for each href
+ returned_idps = [] # type: List[str]
+ for link in p.links:
+ path, query = link.split("?", 1)
+ self.assertEqual(path, "pick_idp")
+ params = urllib.parse.parse_qs(query)
+ self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
+ returned_idps.append(params["idp"][0])
- self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
+ self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
@@ -1211,11 +1215,8 @@ class UsernamePickerTestCase(HomeserverTestCase):
return config
def create_resource_dict(self) -> Dict[str, Resource]:
- from synapse.rest.oidc import OIDCResource
-
d = super().create_resource_dict()
- d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
- d["/_synapse/oidc"] = OIDCResource(self.hs)
+ d.update(build_synapse_client_resource_tree(self.hs))
return d
def test_username_picker(self):
@@ -1229,7 +1230,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
picker_url = channel.headers.getRawHeaders("Location")[0]
- self.assertEqual(picker_url, "/_synapse/client/pick_username")
+ self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
cookies = {} # type: Dict[str,str]
@@ -1253,12 +1254,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
# Now, submit a username to the username picker, which should serve a redirect
- # back to the client
- submit_path = picker_url + "/submit"
+ # to the completion page
content = urlencode({b"username": b"bobby"}).encode("utf8")
chan = self.make_request(
"POST",
- path=submit_path,
+ path=picker_url,
content=content,
content_is_form=True,
custom_headers=[
@@ -1270,6 +1270,16 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index d4e3165436..2548b3a80c 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomInviteRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ admin.register_servlets,
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_rooms_ratelimit(self):
+ """Tests that invites in a room are actually rate-limited."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ for i in range(3):
+ self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,))
+
+ self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429)
+
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invites_by_users_ratelimit(self):
+ """Tests that invites to a specific user are actually rate-limited."""
+
+ for i in range(3):
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red")
+
+ room_id = self.helper.create_room_as(self.user_id)
+ self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429)
+
+
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_email(self):
+ """Test that we ratelimit /requestToken for the same email.
+ """
+ old_password = "monkey"
+ new_password = "kangeroo"
+
+ user_id = self.register_user("kermit", old_password)
+ self.login("kermit", old_password)
+
+ email = "test1@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.store.user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ def reset(ip):
+ client_secret = "foobar"
+ session_id = self._request_token(email, client_secret, ip)
+
+ self.assertEquals(len(self.email_attempts), 1)
+ link = self._get_link_from_email()
+
+ self._validate_token(link)
+
+ self._reset_password(new_password, session_id, client_secret)
+
+ self.email_attempts.clear()
+
+ # We expect to be able to make three requests before getting rate
+ # limited.
+ #
+ # We change IPs to ensure that we're not being ratelimited due to the
+ # same IP
+ reset("127.0.0.1")
+ reset("127.0.0.2")
+ reset("127.0.0.3")
+
+ with self.assertRaises(HttpResponseException) as cm:
+ reset("127.0.0.4")
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_basic_password_reset_canonicalise_email(self):
"""Test basic password reset flow
Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
- def _request_token(self, email, client_secret):
+ def _request_token(self, email, client_secret, ip="127.0.0.1"):
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
+ client_ip=ip,
)
- self.assertEquals(200, channel.code, channel.result)
+
+ if channel.code != 200:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body["sid"]
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+ @override_config({"rc_3pid_validation": {"burst_count": 3}})
+ def test_ratelimit_by_ip(self):
+ """Tests that adding emails is ratelimited by IP
+ """
+
+ # We expect to be able to set three emails before getting ratelimited.
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+ self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+ self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+ with self.assertRaises(HttpResponseException) as cm:
+ self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+ self.assertEqual(cm.exception.code, 429)
+
def test_add_email_if_disabled(self):
"""Test adding email to profile when doing so is disallowed
"""
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
body["next_link"] = next_link
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
- self.assertEquals(expect_code, channel.code, channel.result)
+
+ if channel.code != expect_code:
+ raise HttpResponseException(
+ channel.code, channel.result["reason"], channel.result["body"],
+ )
return channel.json_body.get("sid")
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _add_email(self, request_email, expected_email):
"""Test adding an email to profile
"""
+ previous_email_attempts = len(self.email_attempts)
+
client_secret = "foobar"
session_id = self._request_token(request_email, client_secret)
- self.assertEquals(len(self.email_attempts), 1)
+ self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
link = self._get_link_from_email()
self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+ threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+ self.assertIn(expected_email, threepids)
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index a6488a3d29..3f50c56745 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -22,7 +22,7 @@ from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.rest.oidc import OIDCResource
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.types import JsonDict, UserID
from tests import unittest
@@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
def create_resource_dict(self):
resource_dict = super().create_resource_dict()
- if HAS_OIDC:
- # mount the OIDC resource at /_synapse/oidc
- resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict
def prepare(self, reactor, clock, hs):
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index a6c6985173..c279eb49e3 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -30,6 +30,8 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import make_deferred_yieldable
+from synapse.rest import admin
+from synapse.rest.client.v1 import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -37,6 +39,7 @@ from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest
from tests.server import FakeSite, make_request
+from tests.utils import default_config
class MediaStorageTests(unittest.HomeserverTestCase):
@@ -398,3 +401,94 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
+
+
+class TestSpamChecker:
+ """A spam checker module that rejects all media that includes the bytes
+ `evil`.
+ """
+
+ def __init__(self, config, api):
+ self.config = config
+ self.api = api
+
+ def parse_config(config):
+ return config
+
+ async def check_event_for_spam(self, foo):
+ return False # allow all events
+
+ async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ return True # allow all invites
+
+ async def user_may_create_room(self, userid):
+ return True # allow all room creations
+
+ async def user_may_create_room_alias(self, userid, room_alias):
+ return True # allow all room aliases
+
+ async def user_may_publish_room(self, userid, room_id):
+ return True # allow publishing of all rooms
+
+ async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
+ buf = BytesIO()
+ await file_wrapper.write_chunks_to(buf.write)
+
+ return b"evil" in buf.getvalue()
+
+
+class SpamCheckerTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ admin.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user = self.register_user("user", "pass")
+ self.tok = self.login("user", "pass")
+
+ # Allow for uploading and downloading to/from the media repo
+ self.media_repo = hs.get_media_repository_resource()
+ self.download_resource = self.media_repo.children[b"download"]
+ self.upload_resource = self.media_repo.children[b"upload"]
+
+ def default_config(self):
+ config = default_config("test")
+
+ config.update(
+ {
+ "spam_checker": [
+ {
+ "module": TestSpamChecker.__module__ + ".TestSpamChecker",
+ "config": {},
+ }
+ ]
+ }
+ )
+
+ return config
+
+ def test_upload_innocent(self):
+ """Attempt to upload some innocent data that should be allowed.
+ """
+
+ image_data = unhexlify(
+ b"89504e470d0a1a0a0000000d4948445200000001000000010806"
+ b"0000001f15c4890000000a49444154789c63000100000500010d"
+ b"0a2db40000000049454e44ae426082"
+ )
+
+ self.helper.upload_media(
+ self.upload_resource, image_data, tok=self.tok, expect_code=200
+ )
+
+ def test_upload_ban(self):
+ """Attempt to upload some data that includes bytes "evil", which should
+ get rejected by the spam checker.
+ """
+
+ data = b"Some evil data"
+
+ self.helper.upload_media(
+ self.upload_resource, data, tok=self.tok, expect_code=400
+ )
diff --git a/tests/server.py b/tests/server.py
index 5a85d5fe7f..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,6 +47,7 @@ class FakeChannel:
site = attr.ib(type=Site)
_reactor = attr.ib()
result = attr.ib(type=dict, default=attr.Factory(dict))
+ _ip = attr.ib(type=str, default="127.0.0.1")
_producer = None
@property
@@ -120,7 +121,7 @@ class FakeChannel:
def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU
- return address.IPv4Address("TCP", "127.0.0.1", 3423)
+ return address.IPv4Address("TCP", self._ip, 3423)
def getHost(self):
return None
@@ -196,6 +197,7 @@ def make_request(
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Make a web request using the given method, path and content, and render it
@@ -223,6 +225,9 @@ def make_request(
will pump the reactor until the the renderer tells the channel the request
is finished.
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
channel
"""
@@ -250,7 +255,7 @@ def make_request(
if isinstance(content, str):
content = content.encode("utf8")
- channel = FakeChannel(site, reactor)
+ channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req.content = BytesIO(content)
diff --git a/tests/unittest.py b/tests/unittest.py
index bbd295687c..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
+ client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
custom_headers: (name, value) pairs to add as request headers
+ client_ip: The IP to use as the requesting IP. Useful for testing
+ ratelimiting.
+
Returns:
The FakeChannel object which stores the result of the request.
"""
@@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
content_is_form,
await_result,
custom_headers,
+ client_ip,
)
def setup_test_homeserver(self, *args, **kwargs):
diff --git a/tests/utils.py b/tests/utils.py
index 022223cf24..68033d7535 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -157,6 +157,7 @@ def default_config(name, parse=False):
"local": {"per_second": 10000, "burst_count": 10000},
"remote": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
"saml2_enabled": False,
"default_identity_server": None,
"key_refresh_interval": 24 * 60 * 60 * 1000,
|