diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
index 4b8359762a..5c46dabf32 100644
--- a/tests/rest/client/test_room_access_rules.py
+++ b/tests/rest/client/test_room_access_rules.py
@@ -311,6 +311,52 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
expected_code=200,
)
+ def test_change_rules(self):
+ """Tests that we can only change the current rule from restricted to
+ unrestricted.
+ """
+ # We can change the rule from restricted to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=200,
+ )
+
+ # We can't change the rule from restricted to direct.
+ self.change_rule_in_room(
+ room_id=self.restricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to restricted.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from unrestricted to direct.
+ self.change_rule_in_room(
+ room_id=self.unrestricted_room,
+ new_rule=ACCESS_RULE_DIRECT,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to restricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_RESTRICTED,
+ expected_code=403,
+ )
+
+ # We can't change the rule from direct to unrestricted.
+ self.change_rule_in_room(
+ room_id=self.direct_rooms[0],
+ new_rule=ACCESS_RULE_UNRESTRICTED,
+ expected_code=403,
+ )
+
def create_room(self, direct=False, rule=None, expected_code=200):
content = {
"is_direct": direct,
@@ -349,6 +395,20 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["rule"]
+ def change_rule_in_room(self, room_id, new_rule, expected_code=200):
+ data = {
+ "rule": new_rule,
+ }
+ request, channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, ACCESS_RULES_TYPE),
+ json.dumps(data),
+ access_token=self.tok,
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expected_code, channel.result)
+
def send_threepid_invite(self, address, room_id, expected_code=200):
params = {
"id_server": "testis",
|