diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py
index c973521907..4224b0a92e 100644
--- a/tests/rest/client/test_identity.py
+++ b/tests/rest/client/test_identity.py
@@ -15,15 +15,22 @@
import json
+from mock import Mock
+
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import account
from tests import unittest
-class IdentityTestCase(unittest.HomeserverTestCase):
+class IdentityDisabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts fail when the HS's config disallows them."""
servlets = [
+ account.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
@@ -32,24 +39,111 @@ class IdentityTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["trusted_third_party_id_servers"] = ["testis"]
config["enable_3pid_lookup"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_disabled(self):
+ request, channel = self.make_request(
+ b"POST", "/createRoom", b"{}", access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ room_id = channel.json_body["room_id"]
+
+ params = {
+ "id_server": "testis",
+ "medium": "email",
+ "address": "test@example.com",
+ }
+ request_data = json.dumps(params)
+ request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
+ request, channel = self.make_request(
+ b"POST", request_url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
def test_3pid_lookup_disabled(self):
- self.hs.config.enable_3pid_lookup = False
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+ def test_3pid_bulk_lookup_disabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"403", channel.result)
+
+
+class IdentityEnabledTestCase(unittest.HomeserverTestCase):
+ """Tests that 3PID lookup attempts succeed when the HS's config allows them."""
+
+ servlets = [
+ account.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
- self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["enable_3pid_lookup"] = True
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.get_json.return_value = defer.succeed((200, "{}"))
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.http_client = mock_http_client
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ def test_3pid_invite_enabled(self):
request, channel = self.make_request(
- b"POST", "/createRoom", b"{}", access_token=tok
+ b"POST", "/createRoom", b"{}", access_token=self.tok
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
room_id = channel.json_body["room_id"]
+ # Replace the blacklisting SimpleHttpClient with our mock
+ self.hs.get_room_member_handler().simple_http_client = Mock(
+ spec=["get_json", "post_json_get_json"]
+ )
+ self.hs.get_room_member_handler().simple_http_client.get_json.return_value = defer.succeed(
+ (200, "{}")
+ )
+
params = {
"id_server": "testis",
"medium": "email",
@@ -58,7 +152,44 @@ class IdentityTestCase(unittest.HomeserverTestCase):
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
request, channel = self.make_request(
- b"POST", request_url, request_data, access_token=tok
+ b"POST", request_url, request_data, access_token=self.tok
)
self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ get_json = self.hs.get_handlers().identity_handler.http_client.get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "test@example.com", "medium": "email"},
+ )
+
+ def test_3pid_lookup_enabled(self):
+ url = (
+ "/_matrix/client/unstable/account/3pid/lookup"
+ "?id_server=testis&medium=email&address=foo@bar.baz"
+ )
+ request, channel = self.make_request("GET", url, access_token=self.tok)
+ self.render(request)
+
+ get_json = self.hs.get_simple_http_client().get_json
+ get_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/lookup",
+ {"address": "foo@bar.baz", "medium": "email"},
+ )
+
+ def test_3pid_bulk_lookup_enabled(self):
+ url = "/_matrix/client/unstable/account/3pid/bulk_lookup"
+ data = {
+ "id_server": "testis",
+ "threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]],
+ }
+ request_data = json.dumps(data)
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=self.tok
+ )
+ self.render(request)
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ post_json.assert_called_once_with(
+ "https://testis/_matrix/identity/api/v1/bulk_lookup",
+ {"threepids": [["email", "foo@bar.baz"], ["email", "john.doe@matrix.org"]]},
+ )
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 0b191d13c6..c7e287c61e 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -34,6 +34,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
"default_policy": {
@@ -45,50 +46,63 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
+
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
- def test_retention_state_event(self):
- """Tests that the server configuration can limit the values a user can set to the
- room's retention policy.
+ self.store = self.hs.get_datastore()
+ self.serializer = self.hs.get_event_client_serializer()
+ self.clock = self.hs.get_clock()
+
+ def test_retention_event_purged_with_state_event(self):
+ """Tests that expired events are correctly purged when the room's retention policy
+ is defined by a state event.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Set the room's retention period to 2 days.
+ lifetime = one_day_ms * 2
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_day_ms * 4},
+ body={"max_lifetime": lifetime},
tok=self.token,
- expect_code=400,
)
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+
+ def test_retention_event_purged_with_state_event_outside_allowed(self):
+ """Tests that the server configuration can override the policy for a room when
+ running the purge jobs.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+
+ # Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": one_hour_ms},
+ body={"max_lifetime": one_day_ms * 4},
tok=self.token,
- expect_code=400,
)
- def test_retention_event_purged_with_state_event(self):
- """Tests that expired events are correctly purged when the room's retention policy
- is defined by a state event.
- """
- room_id = self.helper.create_room_as(self.user_id, tok=self.token)
+ # Check that the event is purged after waiting for the maximum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 1.5)
- # Set the room's retention period to 2 days.
- lifetime = one_day_ms * 2
+ # Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
- body={"max_lifetime": lifetime},
+ body={"max_lifetime": one_hour_ms},
tok=self.token,
)
- self._test_retention_event_purged(room_id, one_day_ms * 1.5)
+ # Check that the event is purged after waiting for the minimum allowed duration
+ # instead of the one specified in the room's policy.
+ self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -140,7 +154,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
- def _test_retention_event_purged(self, room_id, increment):
+ def _test_retention_event_purged(self, room_id: str, increment: float):
+ """Run the following test scenario to test the message retention policy support:
+
+ 1. Send event 1
+ 2. Increment time by `increment`
+ 3. Send event 2
+ 4. Increment time by `increment`
+ 5. Check that event 1 has been purged
+ 6. Check that event 2 has not been purged
+ 7. Check that state events that were sent before event 1 aren't purged.
+ The main reason for sending a second event is because currently Synapse won't
+ purge the latest message in a room because it would otherwise result in a lack of
+ forward extremities for this room. It's also a good thing to ensure the purge jobs
+ aren't too greedy and purge messages they shouldn't.
+
+ Args:
+ room_id: The ID of the room to test retention in.
+ increment: The number of milliseconds to advance the clock each time. Must be
+ defined so that events in the room aren't purged if they are `increment`
+ old but are purged if they are `increment * 2` old.
+ """
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
@@ -156,7 +190,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
- expired_event = self.get_event(room_id, expired_event_id)
+ expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -174,26 +208,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
- # Check that the event has been purged from the database.
- self.get_event(room_id, expired_event_id, expected_code=404)
+ # Check that the first event has been purged from the database, i.e. that we
+ # can't retrieve it anymore, because it has expired.
+ self.get_event(expired_event_id, expect_none=True)
- # Check that the event that hasn't been purged can still be retrieved.
- valid_event = self.get_event(room_id, valid_event_id)
+ # Check that the event that hasn't expired can still be retrieved.
+ valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
- def get_event(self, room_id, event_id, expected_code=200):
- url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+ def get_event(self, event_id, expect_none=False):
+ event = self.get_success(self.store.get_event(event_id, allow_none=True))
- request, channel = self.make_request("GET", url, access_token=self.token)
- self.render(request)
+ if expect_none:
+ self.assertIsNone(event)
+ return {}
- self.assertEqual(channel.code, expected_code, channel.result)
+ self.assertIsNotNone(event)
- return channel.json_body
+ time_now = self.clock.time_msec()
+ serialized = self.get_success(self.serializer.serialize_event(event, time_now))
+
+ return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
@@ -205,6 +244,7 @@ class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
+ config["default_room_version"] = "1"
config["retention"] = {
"enabled": True,
}
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
new file mode 100644
index 0000000000..de7856fba9
--- /dev/null
+++ b/tests/rest/client/test_room_access_rules.py
@@ -0,0 +1,1066 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import random
+import string
+from typing import Optional
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.rest import admin
+from synapse.rest.client.v1 import directory, login, room
+from synapse.third_party_rules.access_rules import (
+ ACCESS_RULES_TYPE,
+ AccessRules,
+ RoomAccessRules,
+)
+from synapse.types import JsonDict, create_requester
+
+from tests import unittest
+
+
+class RoomAccessTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ directory.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ config["third_party_event_rules"] = {
+ "module": "synapse.third_party_rules.access_rules.RoomAccessRules",
+ "config": {
+ "domains_forbidden_when_restricted": ["forbidden_domain"],
+ "id_server": "testis",
+ },
+ }
+ config["trusted_third_party_id_servers"] = ["testis"]
+
+ def send_invite(destination, room_id, event_id, pdu):
+ return defer.succeed(pdu)
+
+ def get_json(uri, args={}, headers=None):
+ address_domain = args["address"].split("@")[1]
+ return defer.succeed({"hs": address_domain})
+
+ def post_json_get_json(uri, post_json, args={}, headers=None):
+ token = "".join(random.choice(string.ascii_letters) for _ in range(10))
+ return defer.succeed(
+ {
+ "token": token,
+ "public_keys": [
+ {
+ "public_key": "serverpublickey",
+ "key_validity_url": "https://testis/pubkey/isvalid",
+ },
+ {
+ "public_key": "phemeralpublickey",
+ "key_validity_url": "https://testis/pubkey/ephemeral/isvalid",
+ },
+ ],
+ "display_name": "f...@b...",
+ }
+ )
+
+ mock_federation_client = Mock(spec=["send_invite"])
+ mock_federation_client.send_invite.side_effect = send_invite
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"],)
+ # Mocking the response for /info on the IS API.
+ mock_http_client.get_json.side_effect = get_json
+ # Mocking the response for /store-invite on the IS API.
+ mock_http_client.post_json_get_json.side_effect = post_json_get_json
+ self.hs = self.setup_test_homeserver(
+ config=config,
+ federation_client=mock_federation_client,
+ simple_http_client=mock_http_client,
+ )
+
+ # TODO: This class does not use a singleton to get it's http client
+ # This should be fixed for easier testing
+ # https://github.com/matrix-org/synapse-dinsic/issues/26
+ self.hs.get_handlers().identity_handler.blacklisting_http_client = (
+ mock_http_client
+ )
+
+ self.third_party_event_rules = self.hs.get_third_party_event_rules()
+
+ return self.hs
+
+ def prepare(self, reactor, clock, homeserver):
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.restricted_room = self.create_room()
+ self.unrestricted_room = self.create_room(rule=AccessRules.UNRESTRICTED)
+ self.direct_rooms = [
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ self.create_room(direct=True),
+ ]
+
+ self.invitee_id = self.register_user("invitee", "test")
+ self.invitee_tok = self.login("invitee", "test")
+
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ )
+
+ def test_create_room_no_rule(self):
+ """Tests that creating a room with no rule will set the default."""
+ room_id = self.create_room()
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.RESTRICTED)
+
+ def test_create_room_direct_no_rule(self):
+ """Tests that creating a direct room with no rule will set the default."""
+ room_id = self.create_room(direct=True)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.DIRECT)
+
+ def test_create_room_valid_rule(self):
+ """Tests that creating a room with a valid rule will set the right."""
+ room_id = self.create_room(rule=AccessRules.UNRESTRICTED)
+ rule = self.current_rule_in_room(room_id)
+
+ self.assertEqual(rule, AccessRules.UNRESTRICTED)
+
+ def test_create_room_invalid_rule(self):
+ """Tests that creating a room with an invalid rule will set fail."""
+ self.create_room(rule=AccessRules.DIRECT, expected_code=400)
+
+ def test_create_room_direct_invalid_rule(self):
+ """Tests that creating a direct room with an invalid rule will fail.
+ """
+ self.create_room(direct=True, rule=AccessRules.RESTRICTED, expected_code=400)
+
+ def test_create_room_default_power_level_rules(self):
+ """Tests that a room created with no power level overrides instead uses the dinum
+ defaults
+ """
+ room_id = self.create_room(direct=True, rule=AccessRules.DIRECT)
+ power_levels = self.helper.get_state(room_id, "m.room.power_levels", self.tok)
+
+ # Inviting another user should require PL50, even in private rooms
+ self.assertEqual(power_levels["invite"], 50)
+ # Sending arbitrary state events should require PL100
+ self.assertEqual(power_levels["state_default"], 100)
+
+ def test_create_room_fails_on_incorrect_power_level_rules(self):
+ """Tests that a room created with power levels lower than that required are rejected"""
+ modified_power_levels = RoomAccessRules._get_default_power_levels(self.user_id)
+ modified_power_levels["invite"] = 0
+ modified_power_levels["state_default"] = 50
+
+ self.create_room(
+ direct=True,
+ rule=AccessRules.DIRECT,
+ initial_state=[
+ {"type": "m.room.power_levels", "content": modified_power_levels}
+ ],
+ expected_code=400,
+ )
+
+ def test_existing_room_can_change_power_levels(self):
+ """Tests that a room created with default power levels can have their power levels
+ dropped after room creation
+ """
+ # Creates a room with the default power levels
+ room_id = self.create_room(
+ direct=True, rule=AccessRules.DIRECT, expected_code=200,
+ )
+
+ # Attempt to drop invite and state_default power levels after the fact
+ room_power_levels = self.helper.get_state(
+ room_id, "m.room.power_levels", self.tok
+ )
+ room_power_levels["invite"] = 0
+ room_power_levels["state_default"] = 50
+ self.helper.send_state(
+ room_id, "m.room.power_levels", room_power_levels, self.tok
+ )
+
+ def test_public_room(self):
+ """Tests that it's only possible to have a room listed in the public room list
+ if the access rule is restricted.
+ """
+ # Creating a room with the public_chat preset should succeed and set the access
+ # rule to restricted.
+ preset_room_id = self.create_room(preset=RoomCreationPreset.PUBLIC_CHAT)
+ self.assertEqual(
+ self.current_rule_in_room(preset_room_id), AccessRules.RESTRICTED
+ )
+
+ # Creating a room with the public join rule in its initial state should succeed
+ # and set the access rule to restricted.
+ init_state_room_id = self.create_room(
+ initial_state=[
+ {
+ "type": "m.room.join_rules",
+ "content": {"join_rule": JoinRules.PUBLIC},
+ }
+ ]
+ )
+ self.assertEqual(
+ self.current_rule_in_room(init_state_room_id), AccessRules.RESTRICTED
+ )
+
+ # List preset_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (preset_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # List init_state_room_id in the public room list
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/directory/list/room/%s" % (init_state_room_id,),
+ {"visibility": "public"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Changing access rule to unrestricted should fail.
+ self.change_rule_in_room(
+ preset_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Changing access rule to direct should fail.
+ self.change_rule_in_room(preset_room_id, AccessRules.DIRECT, expected_code=403)
+ self.change_rule_in_room(
+ init_state_room_id, AccessRules.DIRECT, expected_code=403
+ )
+
+ # Creating a new room with the public_chat preset and an access rule of direct
+ # should fail.
+ self.create_room(
+ preset=RoomCreationPreset.PUBLIC_CHAT,
+ rule=AccessRules.DIRECT,
+ expected_code=400,
+ )
+
+ # Changing join rule to public in an direct room should fail.
+ self.change_join_rule_in_room(
+ self.direct_rooms[0], JoinRules.PUBLIC, expected_code=403
+ )
+
+ def test_restricted(self):
+ """Tests that in restricted mode we're unable to invite users from blacklisted
+ servers but can invite other users.
+
+ Also tests that the room can be published to, and removed from, the public room
+ list.
+ """
+ # We can't invite a user from a forbidden HS.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can invite a user which HS isn't forbidden.
+ self.helper.invite(
+ room=self.restricted_room,
+ src=self.user_id,
+ targ="@test:allowed_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.restricted_room,
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.restricted_room,
+ expected_code=200,
+ )
+
+ # We are allowed to publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # We are allowed to remove the room from the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.restricted_room
+ data = {"visibility": "private"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ def test_direct(self):
+ """Tests that, in direct mode, other users than the initial two can't be invited,
+ but the following scenario works:
+ * invited user joins the room
+ * invited user leaves the room
+ * room creator re-invites invited user
+
+ Tests that a user from a HS that's in the list of forbidden domains (to use
+ in restricted mode) can be invited.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ not_invited_user = "@not_invited:forbidden_domain"
+
+ # We can't invite a new user to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # The invited user can join the room.
+ self.helper.join(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can leave the room.
+ self.helper.leave(
+ room=self.direct_rooms[0],
+ user=self.invitee_id,
+ tok=self.invitee_tok,
+ expect_code=200,
+ )
+
+ # The invited user can be re-invited to the room.
+ self.helper.invite(
+ room=self.direct_rooms[0],
+ src=self.user_id,
+ targ=self.invitee_id,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # If we're alone in the room and have always been the only member, we can invite
+ # someone.
+ self.helper.invite(
+ room=self.direct_rooms[1],
+ src=self.user_id,
+ targ=not_invited_user,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # Disable the 3pid invite ratelimiter
+ burst = self.hs.config.rc_third_party_invite.burst_count
+ per_second = self.hs.config.rc_third_party_invite.per_second
+ self.hs.config.rc_third_party_invite.burst_count = 10
+ self.hs.config.rc_third_party_invite.per_second = 0.1
+
+ # We can't send a 3PID invite to a room that already has two members.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[0],
+ expected_code=403,
+ )
+
+ # We can't send a 3PID invite to a room that already has a pending invite.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.direct_rooms[1],
+ expected_code=403,
+ )
+
+ # We can send a 3PID invite to a room in which we've always been the only member.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to a room in which there's a 3PID invite.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.direct_rooms[2],
+ expected_code=403,
+ )
+
+ self.hs.config.rc_third_party_invite.burst_count = burst
+ self.hs.config.rc_third_party_invite.per_second = per_second
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.direct_rooms[0]
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_unrestricted(self):
+ """Tests that, in unrestricted mode, we can invite whoever we want, but we can
+ only change the power level of users that wouldn't be forbidden in restricted
+ mode.
+
+ Tests that the room cannot be published to the public room list.
+ """
+ # We can invite
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=self.user_id,
+ targ="@test:not_forbidden_domain",
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to a forbidden HS.
+ self.send_threepid_invite(
+ address="test@forbidden_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a 3PID invite to an address that is mapped to an HS that's not
+ # forbidden.
+ self.send_threepid_invite(
+ address="test@allowed_domain",
+ room_id=self.unrestricted_room,
+ expected_code=200,
+ )
+
+ # We can send a power level event that doesn't redefine the default PL or set a
+ # non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:not_forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ # We can't send a power level event that redefines the default PL and doesn't set
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={
+ "users": {self.user_id: 100, "@test:not_forbidden_domain": 10},
+ "users_default": 10,
+ },
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't send a power level event that doesn't redefines the default PL but sets
+ # a non-default PL for a user that would be forbidden in restricted mode.
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.PowerLevels,
+ body={"users": {self.user_id: 100, "@test:forbidden_domain": 10}},
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ # We can't publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % self.unrestricted_room
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 403, channel.result)
+
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ # We can change the rule from restricted to unrestricted.
+ # Note that this changes self.restricted_room to an unrestricted room
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=AccessRules.DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=AccessRules.UNRESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't publish a room to the public room list and then change its rule to
+ # unrestricted
+
+ # Create a restricted room
+ test_room_id = self.create_room(rule=AccessRules.RESTRICTED)
+
+ # Publish the room to the public room list
+ url = "/_matrix/client/r0/directory/list/room/%s" % test_room_id
+ data = {"visibility": "public"}
+
+ request, channel = self.make_request("PUT", url, data, access_token=self.tok)
+ self.render(request)
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # Attempt to switch the room to "unrestricted"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.UNRESTRICTED, expected_code=403
+ )
+
+ # Attempt to switch the room to "direct"
+ self.change_rule_in_room(
+ room_id=test_room_id, new_rule=AccessRules.DIRECT, expected_code=403
+ )
+
+ def test_change_room_avatar(self):
+ """Tests that changing the room avatar is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ avatar_content = {
+ "info": {"h": 398, "mimetype": "image/jpeg", "size": 31037, "w": 394},
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ }
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.RoomAvatar,
+ body=avatar_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_name(self):
+ """Tests that changing the room name is always allowed unless the room is a direct
+ chat, in which case it's forbidden.
+ """
+
+ name_content = {"name": "My super room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Name,
+ body=name_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_change_room_topic(self):
+ """Tests that changing the room topic is always allowed unless the room is a
+ direct chat, in which case it's forbidden.
+ """
+
+ topic_content = {"topic": "Welcome to this room"}
+
+ self.helper.send_state(
+ room_id=self.restricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.unrestricted_room,
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=200,
+ )
+
+ self.helper.send_state(
+ room_id=self.direct_rooms[0],
+ event_type=EventTypes.Topic,
+ body=topic_content,
+ tok=self.tok,
+ expect_code=403,
+ )
+
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I",
+ },
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA",
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
+ invite_token = "someothertoken"
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ def test_check_event_allowed(self):
+ """Tests that RoomAccessRules.check_event_allowed behaves accordingly.
+
+ It tests that:
+ * forbidden users cannot join restricted rooms.
+ * forbidden users can only join unrestricted rooms if they have an invite.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+
+ # Test that forbidden users cannot join restricted rooms
+ requester = create_requester(self.user_id)
+ allowed_requester = create_requester("@user:allowed_domain")
+ forbidden_requester = create_requester("@user:forbidden_domain")
+
+ # Create a join event for a forbidden user
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Create a join event for an allowed user
+ allowed_join_event, allowed_join_event_context = self.get_success(
+ event_creator.create_event(
+ allowed_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.restricted_room,
+ "sender": allowed_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": allowed_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Assert a join event from a forbidden user to a restricted room is rejected
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertFalse(can_join)
+
+ # But a join event from an non-forbidden user to a restricted room is allowed
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ allowed_join_event, allowed_join_event_context
+ )
+ )
+ self.assertTrue(can_join)
+
+ # Test that forbidden users can only join unrestricted rooms if they have an invite
+
+ # Recreate the forbidden join event for the unrestricted room instead
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # A forbidden user without an invite should not be able to join an unrestricted room
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertFalse(can_join)
+
+ # However, if we then invite this user...
+ self.helper.invite(
+ room=self.unrestricted_room,
+ src=requester.user.to_string(),
+ targ=forbidden_requester.user.to_string(),
+ tok=self.tok,
+ )
+
+ # And create another join event, making sure that its context states it's coming
+ # in after the above invite was made...
+ forbidden_join_event, forbidden_join_event_context = self.get_success(
+ event_creator.create_event(
+ forbidden_requester,
+ {
+ "type": EventTypes.Member,
+ "room_id": self.unrestricted_room,
+ "sender": forbidden_requester.user.to_string(),
+ "content": {"membership": Membership.JOIN},
+ "state_key": forbidden_requester.user.to_string(),
+ },
+ )
+ )
+
+ # Then the forbidden user should be able to join!
+ can_join = self.get_success(
+ self.third_party_event_rules.check_event_allowed(
+ forbidden_join_event, forbidden_join_event_context
+ )
+ )
+ self.assertTrue(can_join)
+
+ def test_freezing_a_room(self):
+ """Tests that the power levels in a room change to prevent new events from
+ non-admin users when the last admin of a room leaves.
+ """
+
+ def freeze_room_with_id_and_power_levels(
+ room_id: str, custom_power_levels_content: Optional[JsonDict] = None,
+ ):
+ # Invite a user to the room, they join with PL 0
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ if not custom_power_levels_content:
+ # Retrieve the room's current power levels event content
+ power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+ else:
+ power_levels = custom_power_levels_content
+
+ # Override the room's power levels with the given power levels content
+ self.helper.send_state(
+ room_id=room_id,
+ event_type="m.room.power_levels",
+ body=custom_power_levels_content,
+ tok=self.tok,
+ )
+
+ # Ensure that the invitee leaving the room does not change the power levels
+ self.helper.leave(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Retrieve the new power levels of the room
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.tok,
+ )
+
+ # Ensure they have not changed
+ self.assertDictEqual(power_levels, new_power_levels)
+
+ # Invite the user back again
+ self.helper.invite(
+ room=room_id, src=self.user_id, targ=self.invitee_id, tok=self.tok,
+ )
+
+ # Invitee joins the room
+ self.helper.join(
+ room=room_id, user=self.invitee_id, tok=self.invitee_tok,
+ )
+
+ # Now the admin leaves the room
+ self.helper.leave(
+ room=room_id, user=self.user_id, tok=self.tok,
+ )
+
+ # Check the power levels again
+ new_power_levels = self.helper.get_state(
+ room_id=room_id, event_type="m.room.power_levels", tok=self.invitee_tok,
+ )
+
+ # Ensure that the new power levels prevent anyone but admins from sending
+ # certain events
+ self.assertEquals(new_power_levels["state_default"], 100)
+ self.assertEquals(new_power_levels["events_default"], 100)
+ self.assertEquals(new_power_levels["kick"], 100)
+ self.assertEquals(new_power_levels["invite"], 100)
+ self.assertEquals(new_power_levels["ban"], 100)
+ self.assertEquals(new_power_levels["redact"], 100)
+ self.assertDictEqual(new_power_levels["events"], {})
+ self.assertDictEqual(new_power_levels["users"], {self.user_id: 100})
+
+ # Ensure new users entering the room aren't going to immediately become admins
+ self.assertEquals(new_power_levels["users_default"], 0)
+
+ # Test that freezing a room with the default power level state event content works
+ room1 = self.create_room()
+ freeze_room_with_id_and_power_levels(room1)
+
+ # Test that freezing a room with a power level state event that is missing
+ # `state_default` and `event_default` keys behaves as expected
+ room2 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room2,
+ {
+ "ban": 50,
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "invite": 0,
+ "kick": 50,
+ "redact": 50,
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ },
+ )
+
+ # Test that freezing a room with a power level state event that is *additionally*
+ # missing `ban`, `invite`, `kick` and `redact` keys behaves as expected
+ room3 = self.create_room()
+ freeze_room_with_id_and_power_levels(
+ room3,
+ {
+ "events": {
+ "m.room.avatar": 50,
+ "m.room.canonical_alias": 50,
+ "m.room.history_visibility": 100,
+ "m.room.name": 50,
+ "m.room.power_levels": 100,
+ },
+ "users": {self.user_id: 100},
+ "users_default": 0,
+ # Explicitly remove `state_default` and `event_default` keys
+ # Explicitly remove `ban`, `invite`, `kick` and `redact` keys
+ },
+ )
+
+ def create_room(
+ self,
+ direct=False,
+ rule=None,
+ preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
+ initial_state=None,
+ expected_code=200,
+ ):
+ content = {"is_direct": direct, "preset": preset}
+
+ if rule:
+ content["initial_state"] = [
+ {"type": ACCESS_RULES_TYPE, "state_key": "", "content": {"rule": rule}}
+ ]
+
+ if initial_state:
+ if "initial_state" not in content:
+ content["initial_state"] = []
+
+ content["initial_state"] += initial_state
+
+ request, channel = self.make_request(
+ "POST", "/_matrix/client/r0/createRoom", content, access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ if expected_code == 200:
+ return channel.json_body["room_id"]
+
+ def current_rule_in_room(self, room_id):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["rule"]
+
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {"rule": new_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def change_join_rule_in_room(self, room_id, new_join_rule, expected_code=200):
+ data = {"join_rule": new_join_rule}
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, EventTypes.JoinRules),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_threepid_invite(self, address, room_id, expected_code=200):
+ params = {"id_server": "testis", "medium": "email", "address": address}
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/invite" % room_id,
+ json.dumps(params),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id,
+ event_type,
+ state_key,
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
new file mode 100644
index 0000000000..d03e121664
--- /dev/null
+++ b/tests/rest/client/test_third_party_rules.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from typing import Dict
+
+from mock import Mock
+
+from synapse.events import EventBase
+from synapse.module_api import ModuleApi
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import Requester, StateMap
+
+from tests import unittest
+
+thread_local = threading.local()
+
+
+class ThirdPartyRulesTestModule:
+ def __init__(self, config: Dict, module_api: ModuleApi):
+ # keep a record of the "current" rules module, so that the test can patch
+ # it if desired.
+ thread_local.rules_module = self
+ self.module_api = module_api
+
+ async def on_create_room(
+ self, requester: Requester, config: dict, is_requester_admin: bool
+ ):
+ return True
+
+ async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ return config
+
+
+def current_rules_module() -> ThirdPartyRulesTestModule:
+ return thread_local.rules_module
+
+
+class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+ config["third_party_event_rules"] = {
+ "module": __name__ + ".ThirdPartyRulesTestModule",
+ "config": {},
+ }
+ return config
+
+ def prepare(self, reactor, clock, homeserver):
+ # Create a user and room to play with during the tests
+ self.user_id = self.register_user("kermit", "monkey")
+ self.tok = self.login("kermit", "monkey")
+
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+ def test_third_party_rules(self):
+ """Tests that a forbidden event is forbidden from being sent, but an allowed one
+ can be sent.
+ """
+ # patch the rules module with a Mock which will return False for some event
+ # types
+ async def check(ev, state):
+ return ev.type != "foo.bar.forbidden"
+
+ callback = Mock(spec=[], side_effect=check)
+ current_rules_module().check_event_allowed = callback
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ callback.assert_called_once()
+
+ # there should be various state events in the state arg: do some basic checks
+ state_arg = callback.call_args[0][1]
+ for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
+ self.assertIn(k, state_arg)
+ ev = state_arg[k]
+ self.assertEqual(ev.type, k[0])
+ self.assertEqual(ev.state_key, k[1])
+
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"403", channel.result)
+
+ def test_modify_event(self):
+ """Tests that the module can successfully tweak an event before it is persisted.
+ """
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ ev.content = {"x": "y"}
+ return True
+
+ current_rules_module().check_event_allowed = check
+
+ # now send the event
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ event_id = channel.json_body["event_id"]
+
+ # ... and check that it got modified
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
+ access_token=self.tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["x"], "y")
+
+ def test_send_event(self):
+ """Tests that the module can send an event into a room via the module api"""
+ content = {
+ "msgtype": "m.text",
+ "body": "Hello!",
+ }
+ event_dict = {
+ "room_id": self.room_id,
+ "type": "m.room.message",
+ "content": content,
+ "sender": self.user_id,
+ }
+ event = self.get_success(
+ current_rules_module().module_api.create_and_send_event_into_room(
+ event_dict
+ )
+ ) # type: EventBase
+
+ self.assertEquals(event.sender, self.user_id)
+ self.assertEquals(event.room_id, self.room_id)
+ self.assertEquals(event.type, "m.room.message")
+ self.assertEquals(event.content, content)
diff --git a/tests/rest/client/third_party_rules.py b/tests/rest/client/third_party_rules.py
deleted file mode 100644
index 7167fc56b6..0000000000
--- a/tests/rest/client/third_party_rules.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
-
-from tests import unittest
-
-
-class ThirdPartyRulesTestModule(object):
- def __init__(self, config):
- pass
-
- def check_event_allowed(self, event, context):
- if event.type == "foo.bar.forbidden":
- return False
- else:
- return True
-
- @staticmethod
- def parse_config(config):
- return config
-
-
-class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
- servlets = [
- admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- config = self.default_config()
- config["third_party_event_rules"] = {
- "module": "tests.rest.client.third_party_rules.ThirdPartyRulesTestModule",
- "config": {},
- }
-
- self.hs = self.setup_test_homeserver(config=config)
- return self.hs
-
- def test_third_party_rules(self):
- """Tests that a forbidden event is forbidden from being sent, but an allowed one
- can be sent.
- """
- user_id = self.register_user("kermit", "monkey")
- tok = self.login("kermit", "monkey")
-
- room_id = self.helper.create_room_as(user_id, tok=tok)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"200", channel.result)
-
- request, channel = self.make_request(
- "PUT",
- "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % room_id,
- {},
- access_token=tok,
- )
- self.render(request)
- self.assertEquals(channel.result["code"], b"403", channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index db52725cfe..2668662c9e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
- request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+ request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
- request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e74bddc1e5..68c4a6a8f7 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse
-from mock import Mock
+from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
-from synapse.rest.client.v2_alpha import account
+from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
]
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
- for i in range(5):
+ for i in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
- # Create and join more rooms than the rate-limiting config allows in a second.
+ # Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
- self.reactor.advance(1)
- room_ids = room_ids + [
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- self.helper.create_room_as(self.user_id),
- ]
+ # Let some time for the rate-limiter to forget about our multi-join.
+ self.reactor.advance(2)
+ # Add one to make sure we're joined to more rooms than the config allows us to
+ # join in a second.
+ room_ids.append(self.helper.create_room_as(self.user_id))
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
- store.create_profile(UserID.from_string(self.user_id).localpart)
+ self.get_success(
+ store.create_profile(UserID.from_string(self.user_id).localpart)
+ )
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
- {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+ {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
- for i in range(6):
+ for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
@@ -2059,3 +2060,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
+
+
+# To avoid the tests timing out don't add a delay to "annoy the requester".
+@patch("random.randint", new=lambda a, b: 0)
+class ShadowBannedTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ room_upgrade_rest_servlet.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.banned_user_id = self.register_user("banned", "test")
+ self.banned_access_token = self.login("banned", "test")
+
+ self.store = self.hs.get_datastore()
+
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="users",
+ keyvalues={"name": self.banned_user_id},
+ updatevalues={"shadow_banned": True},
+ desc="shadow_ban",
+ )
+ )
+
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def test_invite(self):
+ """Invites from shadow-banned users don't actually get sent."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ self.helper.invite(
+ room=room_id,
+ src=self.banned_user_id,
+ tok=self.banned_access_token,
+ targ=self.other_user_id,
+ )
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ def test_invite_3pid(self):
+ """Ensure that a 3PID invite does not attempt to contact the identity server."""
+ identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler.lookup_3pid = Mock(
+ side_effect=AssertionError("This should not get called")
+ )
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ # Inviting the user completes successfully.
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/invite" % (room_id,),
+ {"id_server": "test", "medium": "email", "address": "test@test.test"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # This should have raised an error earlier, but double check this wasn't called.
+ identity_handler.lookup_3pid.assert_not_called()
+
+ def test_create_room(self):
+ """Invitations during a room creation should be discarded, but the room still gets created."""
+ # The room creation is successful.
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/createRoom",
+ {"visibility": "public", "invite": [self.other_user_id]},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ room_id = channel.json_body["room_id"]
+
+ # But the user wasn't actually invited.
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(self.other_user_id)
+ )
+ self.assertEqual(invited_rooms, [])
+
+ # Since a real room was created, the other user should be able to join it.
+ self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
+
+ # Both users should be in the room.
+ users = self.get_success(self.store.get_users_in_room(room_id))
+ self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+
+ def test_message(self):
+ """Messages from shadow-banned users don't actually get sent."""
+
+ room_id = self.helper.create_room_as(
+ self.other_user_id, tok=self.other_access_token
+ )
+
+ # The user should be in the room.
+ self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
+
+ # Sending a message should complete successfully.
+ result = self.helper.send_event(
+ room_id=room_id,
+ type=EventTypes.Message,
+ content={"msgtype": "m.text", "body": "with right label"},
+ tok=self.banned_access_token,
+ )
+ self.assertIn("event_id", result)
+ event_id = result["event_id"]
+
+ latest_events = self.get_success(
+ self.store.get_latest_event_ids_in_room(room_id)
+ )
+ self.assertNotIn(event_id, latest_events)
+
+ def test_upgrade(self):
+ """A room upgrade should fail, but look like it succeeded."""
+
+ # The create works fine.
+ room_id = self.helper.create_room_as(
+ self.banned_user_id, tok=self.banned_access_token
+ )
+
+ request, channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
+ {"new_version": "6"},
+ access_token=self.banned_access_token,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+ # A new room_id should be returned.
+ self.assertIn("replacement_room", channel.json_body)
+
+ new_room_id = channel.json_body["replacement_room"]
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(new_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 152a5182fa..0a51aeff92 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -14,11 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
import os
import re
from email.parser import Parser
+from typing import Optional
import pkg_resources
@@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register
from tests import unittest
+from tests.unittest import override_config
class PasswordResetTestCase(unittest.HomeserverTestCase):
@@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
- def _request_token(self, email, client_secret):
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link(self):
+ """Tests a valid next_link parameter value with no whitelist (good case)"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/good/site",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_exotic_protocol(self):
+ """Tests using a esoteric protocol as a next_link parameter value.
+ Someone may be hosting a client on IPFS etc.
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
+ expect_code=200,
+ )
+
+ @override_config({"next_link_domain_whitelist": None})
+ def test_next_link_file_uri(self):
+ """Tests next_link parameters cannot be file URI"""
+ # Attempt to use a next_link value that points to the local disk
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="file:///host/path",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
+ def test_next_link_domain_whitelist(self):
+ """Tests next_link parameters must fit the whitelist if provided"""
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/some/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.org/some/also/good/page",
+ expect_code=200,
+ )
+
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://bad.example.org/some/bad/page",
+ expect_code=400,
+ )
+
+ @override_config({"next_link_domain_whitelist": []})
+ def test_empty_next_link_domain_whitelist(self):
+ """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
+ disallowed
+ """
+ self._request_token(
+ "something@example.com",
+ "some_secret",
+ next_link="https://example.com/a/page",
+ expect_code=400,
+ )
+
+ def _request_token(
+ self,
+ email: str,
+ client_secret: str,
+ next_link: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> str:
+ """Request a validation token to add an email address to a user's account
+
+ Args:
+ email: The email address to validate
+ client_secret: A secret string
+ next_link: A link to redirect the user to after validation
+ expect_code: Expected return code of the call
+
+ Returns:
+ The ID of the new threepid validation session
+ """
+ body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
+ if next_link:
+ body["next_link"] = next_link
+
request, channel = self.make_request(
- "POST",
- b"account/3pid/email/requestToken",
- {"client_secret": client_secret, "email": email, "send_attempt": 1},
+ "POST", b"account/3pid/email/requestToken", body,
)
self.render(request)
- self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(expect_code, channel.code, channel.result)
- return channel.json_body["sid"]
+ return channel.json_body.get("sid")
def _request_token_invalid_email(
self, email, expected_errcode, expected_error, client_secret="foobar",
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 53a43038f0..ecf697e5e0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -19,8 +19,12 @@ import datetime
import json
import os
+from mock import Mock
+
import pkg_resources
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -87,14 +91,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")
- def test_POST_bad_username(self):
- request_data = json.dumps({"username": 777, "password": "monkey"})
- request, channel = self.make_request(b"POST", self.url, request_data)
- self.render(request)
-
- self.assertEquals(channel.result["code"], b"400", channel.result)
- self.assertEquals(channel.json_body["error"], "Invalid username")
-
def test_POST_user_valid(self):
user_id = "@kermit:test"
device_id = "frogfone"
@@ -160,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
- self.reactor.advance(retry_after_ms / 1000.0)
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -303,6 +299,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(channel.json_body.get("sid"))
+class RegisterHideProfileTestCase(unittest.HomeserverTestCase):
+
+ servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
+
+ def make_homeserver(self, reactor, clock):
+
+ self.url = b"/_matrix/client/r0/register"
+
+ config = self.default_config()
+ config["enable_registration"] = True
+ config["show_users_in_user_directory"] = False
+ config["replicate_user_profiles_to"] = ["fakeserver"]
+
+ mock_http_client = Mock(spec=["get_json", "post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_profile_hidden(self):
+ user_id = self.register_user("kermit", "monkey")
+
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # We expect post_json_get_json to have been called twice: once with the original
+ # profile and once with the None profile resulting from the request to hide it
+ # from the user directory.
+ self.assertEqual(post_json.call_count, 2, post_json.call_args_list)
+
+ # Get the args (and not kwargs) passed to post_json.
+ args = post_json.call_args[0]
+ # Make sure the last call was attempting to replicate profiles.
+ split_uri = args[0].split("/")
+ self.assertEqual(split_uri[len(split_uri) - 1], "replicate_profiles", args[0])
+ # Make sure the last profile update was overriding the user's profile to None.
+ self.assertEqual(args[1]["batch"][user_id], None, args[1])
+
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -312,6 +349,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
sync.register_servlets,
logout.register_servlets,
account_validity.register_servlets,
+ account.register_servlets,
]
def make_homeserver(self, reactor, clock):
@@ -437,6 +475,155 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+class AccountValidityUserDirectoryTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.client.v1.profile.register_servlets,
+ synapse.rest.client.v1.room.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ login.register_servlets,
+ register.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+
+ # Set accounts to expire after a week
+ config["enable_registration"] = True
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ config["replicate_user_profiles_to"] = "test.is"
+
+ # Mock homeserver requests to an identity server
+ mock_http_client = Mock(spec=["post_json_get_json"])
+ mock_http_client.post_json_get_json.return_value = defer.succeed((200, "{}"))
+
+ self.hs = self.setup_test_homeserver(
+ config=config, simple_http_client=mock_http_client
+ )
+
+ return self.hs
+
+ def test_expired_user_in_directory(self):
+ """Test that an expired user is hidden in the user directory"""
+ # Create an admin user to search the user directory
+ admin_id = self.register_user("admin", "adminpassword", admin=True)
+ admin_tok = self.login("admin", "adminpassword")
+
+ # Ensure the admin never expires
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": admin_id,
+ "expiration_ts": 999999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Mock the homeserver's HTTP client
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+
+ # Create a user
+ username = "kermit"
+ user_id = self.register_user(username, "monkey")
+ self.login(username, "monkey")
+ self.get_success(
+ self.hs.get_datastore().set_profile_displayname(username, "mr.kermit", 1)
+ )
+
+ # Check that a full profile for this user is replicated
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+ # Expire the user
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ # Wait for the background job to run which hides expired users in the directory
+ self.reactor.advance(60 * 60 * 1000)
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ self.assertIsNotNone(post_json.call_args, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+
+ self.assertIsNotNone(batch, batch)
+ self.assertEquals(len(batch), 1, batch)
+
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's None, signifying that the user should be removed from the user
+ # directory because they were expired
+ replicated_content = batch[user_id]
+ self.assertIsNone(replicated_content)
+
+ # Now renew the user, and check they get replicated again to the identity server
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ params = {
+ "user_id": user_id,
+ "expiration_ts": 99999999999,
+ "enable_renewal_emails": False,
+ }
+ request_data = json.dumps(params)
+ request, channel = self.make_request(
+ b"POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.pump(10)
+ self.reactor.advance(10)
+ self.pump()
+
+ # Check if the homeserver has replicated the user's profile to the identity server
+ post_json = self.hs.get_simple_http_client().post_json_get_json
+ self.assertNotEquals(post_json.call_args, None, post_json.call_args)
+ payload = post_json.call_args[0][1]
+ batch = payload.get("batch")
+ self.assertNotEquals(batch, None, batch)
+ self.assertEquals(len(batch), 1, batch)
+ replicated_user_id = list(batch.keys())[0]
+ self.assertEquals(replicated_user_id, user_id, replicated_user_id)
+
+ # There was replicated information about our user
+ # Check that it's not None, signifying that the user is back in the user
+ # directory
+ replicated_content = batch[user_id]
+ self.assertIsNotNone(replicated_content)
+
+
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -587,7 +774,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
- self.assertEqual(request.code, 200)
+ self.assertEqual(request.code, 200, channel.result)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
|