diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index 2cb77e8115..a7fd9eab2e 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -232,6 +232,9 @@ class RoomAccessRules(object):
if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
return self._on_membership_or_invite(event, rule, state_events)
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
return True
def _on_rules_change(self, event, state_events):
@@ -251,6 +254,12 @@ class RoomAccessRules(object):
if new_rule not in VALID_ACCESS_RULES:
return False
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
# Make sure we don't apply "direct" if the room has more than two members.
if new_rule == ACCESS_RULE_DIRECT:
existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
@@ -400,7 +409,6 @@ class RoomAccessRules(object):
access_rule (str): The access rule in place in this room.
Returns:
bool, True if the event can be allowed, False otherwise.
-
"""
# Check if we need to apply the restrictions with the current rule.
if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
@@ -424,6 +432,22 @@ class RoomAccessRules(object):
return True
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get('join_rule') == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
@staticmethod
def _get_rule_from_state(state_events):
"""Extract the rule to be applied from the given set of state events.
@@ -442,6 +466,21 @@ class RoomAccessRules(object):
return rule
@staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the join rule (either "public", or "invite")
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return ""
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
def _get_members_and_tokens_from_state(state_events):
"""Retrieves from a list of state events the list of users that have a
m.room.member event in the room, and the tokens of 3PID invites in the room.
|