diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 9033f09fd2..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -398,7 +392,7 @@ class CASTestCase(unittest.HomeserverTestCase):
</cas:serviceResponse>
"""
% cas_user_id
- )
+ ).encode("utf-8")
mocked_http_client = Mock(spec=["get_raw"])
mocked_http_client.get_raw.side_effect = get_raw
@@ -514,19 +508,22 @@ class JWTTestCase(unittest.HomeserverTestCase):
]
jwt_secret = "secret"
+ jwt_algorithm = "HS256"
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.jwt_enabled = True
self.hs.config.jwt_secret = self.jwt_secret
- self.hs.config.jwt_algorithm = "HS256"
+ self.hs.config.jwt_algorithm = self.jwt_algorithm
return self.hs
def jwt_encode(self, token, secret=jwt_secret):
- return jwt.encode(token, secret, "HS256").decode("ascii")
+ return jwt.encode(token, secret, self.jwt_algorithm).decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -544,35 +541,126 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "JWT expired")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Signature has expired"
+ )
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: The token is not yet valid (nbf)",
+ )
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "issuer": "test-issuer",
+ }
+ }
+ )
+ def test_login_iss(self):
+ """Test validating the issuer claim."""
+ # A valid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid issuer.
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid issuer"
+ )
+
+ # Not providing an issuer.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "iss" claim',
+ )
+
+ def test_login_iss_no_config(self):
+ """Test providing an issuer claim without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ @override_config(
+ {
+ "jwt_config": {
+ "jwt_enabled": True,
+ "secret": jwt_secret,
+ "algorithm": jwt_algorithm,
+ "audiences": ["test-audience"],
+ }
+ }
+ )
+ def test_login_aud(self):
+ """Test validating the audience claim."""
+ # A valid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ self.assertEqual(channel.json_body["user_id"], "@kermit:test")
+
+ # An invalid audience.
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
+ # Not providing an audience.
+ channel = self.jwt_login({"sub": "kermit"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ 'JWT validation failed: Token is missing the "aud" claim',
+ )
+
+ def test_login_aud_no_config(self):
+ """Test providing an audience without requiring it in the configuration."""
+ channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"], "JWT validation failed: Invalid audience"
+ )
+
def test_login_no_token(self):
- params = json.dumps({"type": "m.login.jwt"})
+ params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -640,7 +728,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
- params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
+ params = json.dumps(
+ {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
+ )
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@@ -652,6 +742,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
- self.assertEqual(channel.result["code"], b"401", channel.result)
- self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
- self.assertEqual(channel.json_body["error"], "Invalid JWT")
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ self.assertEqual(
+ channel.json_body["error"],
+ "JWT validation failed: Signature verification failed",
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 0fdff79aa7..3c66255dac 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def test_put_presence_disabled(self):
"""
- PUT to the status endpoint with use_presence disbled will NOT call
+ PUT to the status endpoint with use_presence disabled will NOT call
set_state on the presence handler.
"""
self.hs.config.use_presence = False
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 4886bbb401..0a567b032f 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -19,18 +19,16 @@
"""Tests REST events for /rooms paths."""
import json
+from urllib import parse as urlparse
from mock import Mock
-from six.moves.urllib import parse as urlparse
-
-from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
@@ -677,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
+class RoomJoinRatelimitTestCase(RoomBase):
+ user_id = "@sid1:red"
+
+ servlets = [
+ profile.register_servlets,
+ room.register_servlets,
+ ]
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit(self):
+ """Tests that local joins are actually rate-limited."""
+ for i in range(3):
+ self.helper.create_room_as(self.user_id)
+
+ self.helper.create_room_as(self.user_id, expect_code=429)
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_profile_change(self):
+ """Tests that sending a profile update into all of the user's joined rooms isn't
+ rate-limited by the rate-limiter on joins."""
+
+ # Create and join as many rooms as the rate-limiting config allows in a second.
+ room_ids = [
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ self.helper.create_room_as(self.user_id),
+ ]
+ # Let some time for the rate-limiter to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
+
+ # Create a profile for the user, since it hasn't been done on registration.
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
+
+ # Update the display name for the user.
+ path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+ request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+ self.render(request)
+ self.assertEquals(channel.code, 200, channel.json_body)
+
+ # Check that all the rooms have been sent a profile update into.
+ for room_id in room_ids:
+ path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+ room_id,
+ self.user_id,
+ )
+
+ request, channel = self.make_request("GET", path)
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+ self.assertIn("displayname", channel.json_body)
+ self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+ @unittest.override_config(
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_join_local_ratelimit_idempotent(self):
+ """Tests that the room join endpoints remain idempotent despite rate-limiting
+ on room joins."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # Let's test both paths to be sure.
+ paths_to_test = [
+ "/_matrix/client/r0/rooms/%s/join",
+ "/_matrix/client/r0/join/%s",
+ ]
+
+ for path in paths_to_test:
+ # Make sure we send more requests than the rate-limiting config would allow
+ # if all of these requests ended up joining the user to a room.
+ for i in range(4):
+ request, channel = self.make_request("POST", path % room_id, {})
+ self.render(request)
+ self.assertEquals(channel.code, 200)
+
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 22d734e763..afaf9f7b85 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -30,7 +30,7 @@ from tests.server import make_request, render
@attr.s
-class RestHelper(object):
+class RestHelper:
"""Contains extra helper functions to quickly and clearly perform a given
REST action, which isn't the focus of the test.
"""
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
- def create_room_as(self, room_creator=None, is_public=True, tok=None):
+ def create_room_as(
+ self, room_creator=None, is_public=True, tok=None, expect_code=200,
+ ):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
- assert channel.result["code"] == b"200", channel.result
+ assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
- return channel.json_body["room_id"]
+
+ if expect_code == 200:
+ return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
@@ -88,7 +92,28 @@ class RestHelper(object):
expect_code=expect_code,
)
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ def change_membership(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ membership: str,
+ extra_data: dict = {},
+ tok: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> None:
+ """
+ Send a membership state event into a room.
+
+ Args:
+ room: The ID of the room to send to
+ src: The mxid of the event sender
+ targ: The mxid of the event's target. The state key
+ membership: The type of membership event
+ extra_data: Extra information to include in the content of the event
+ tok: The user access token to use
+ expect_code: The expected HTTP response code
+ """
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -97,6 +122,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
data = {"membership": membership}
+ data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
|