diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index 1f03138752..a8b3ed9458 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -28,9 +28,31 @@ ACCESS_RULE_DIRECT = "direct"
class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
def __init__(self, config, http_client):
self.http_client = http_client
+
self.id_server = config["id_server"]
+
self.domains_forbidden_when_restricted = config.get(
"domains_forbidden_when_restricted", [],
)
@@ -43,34 +65,77 @@ class RoomAccessRules(object):
raise ConfigError("No IS for event rules TchapEventRules")
def on_create_room(self, requester, config, is_requester_admin):
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ rules_in_initial_state = False
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and fix it if not.
for event in config.get("initial_state", []):
if event["type"] == ACCESS_RULES_TYPE:
- # If there's already a rules event in the initial state, check if it
- # breaks the rules for "direct", and if not don't do anything else.
- if (
- not config.get("is_direct")
- or event["content"]["rule"] != ACCESS_RULE_DIRECT
- ):
- return
-
- # Append an access rules event to be sent once every other event in initial_state
- # has been sent. If "is_direct" exists and is set to True, the rule needs to be
- # "direct", and "restricted" otherwise.
- if config.get("is_direct"):
- default_rule = ACCESS_RULE_DIRECT
- else:
- default_rule = ACCESS_RULE_RESTRICTED
+ rules_in_initial_state = True
+
+ rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if rule is None:
+ event["content"] = {
+ "rule": self._on_create_room_default_rule(is_direct)
+ }
+
+ # Make sure the rule name is valid.
+ if not self._is_rule_name_valid(rule):
+ event["content"]["rule"] = self._on_create_room_default_rule(
+ is_direct,
+ )
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if is_direct and rule != ACCESS_RULE_DIRECT:
+ event["content"]["rule"] = ACCESS_RULE_DIRECT
+
+ # Make sure the rule is not "direct" if the room isn't a direct chat.
+ if rule == ACCESS_RULE_DIRECT and not is_direct:
+ event["content"]["rule"] = ACCESS_RULE_RESTRICTED
+
+ # If there's no rules event in the initial state, create one with the default
+ # setting.
+ if not rules_in_initial_state:
+ config["initial_state"].append({
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {
+ "rule": self._on_create_room_default_rule(is_direct),
+ }
+ })
- config["initial_state"].append({
- "type": ACCESS_RULES_TYPE,
- "state_key": "",
- "content": {
- "rule": default_rule,
- }
- })
+ @staticmethod
+ def _on_create_room_default_rule(is_direct):
+ """Returns the default rule to set.
+
+ Args:
+ is_direct (bool): Is the room created with "is_direct" set to True.
+
+ Returns:
+ str, the name of the rule tu use as the default.
+ """
+ if is_direct:
+ return ACCESS_RULE_DIRECT
+ else:
+ return ACCESS_RULE_RESTRICTED
@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
rule = self._get_rule_from_state(state_events)
if medium != "email":
@@ -105,6 +170,11 @@ class RoomAccessRules(object):
defer.returnValue(True)
def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
# Special-case the access rules event.
if event.type == ACCESS_RULES_TYPE:
return self._on_rules_change(event, state_events)
@@ -125,14 +195,20 @@ class RoomAccessRules(object):
return ret
def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
new_rule = event.content.get("rule")
# Check for invalid values.
- if (
- new_rule != ACCESS_RULE_DIRECT
- and new_rule != ACCESS_RULE_RESTRICTED
- and new_rule != ACCESS_RULE_UNRESTRICTED
- ):
+ if not self._is_rule_name_valid(new_rule):
return False
# Make sure we don't apply "direct" if the room has more than two members.
@@ -161,16 +237,37 @@ class RoomAccessRules(object):
return False
def _apply_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
# "restricted" currently means that users can only invite users if their server is
# included in a limited list of domains.
invitee_domain = DomainRuleChecker._get_domain_from_id(event.state_key)
return invitee_domain not in self.domains_forbidden_when_restricted
def _apply_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
# "unrestricted" currently means that every event is allowed.
return True
def _apply_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
# "direct" currently means that no member is allowed apart from the two initial
# members the room was created for (i.e. the room's creator and their first
# invitee).
@@ -235,6 +332,14 @@ class RoomAccessRules(object):
@staticmethod
def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
if access_rules is None:
rule = ACCESS_RULE_RESTRICTED
@@ -244,5 +349,27 @@ class RoomAccessRules(object):
@staticmethod
def _is_invite_from_threepid(invite, threepid_invite):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite (EventBase): The m.room.third_party_invite event.
+ """
token = invite.content.get("third_party_signed", {}).get("token", "")
return token == threepid_invite.state_key
+
+ @staticmethod
+ def _is_rule_name_valid(rule):
+ """Returns whether the given rule name is within the allowed values ("direct",
+ "restricted" or "unrestricted").
+
+ Args:
+ rule (str): The name of the rule.
+ Returns:
+ bool, True if the name is valid, False otherwise.
+ """
+ return (
+ rule == ACCESS_RULE_DIRECT
+ or rule == ACCESS_RULE_RESTRICTED
+ or rule == ACCESS_RULE_UNRESTRICTED
+ )
|