summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/third_party_rules/access_rules.py181
1 files changed, 154 insertions, 27 deletions
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index 1f03138752..a8b3ed9458 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -28,9 +28,31 @@ ACCESS_RULE_DIRECT = "direct"
 
 
 class RoomAccessRules(object):
+    """Implementation of the ThirdPartyEventRules module API that allows federation admins
+    to define custom rules for specific events and actions.
+    Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+    Takes a config in the format:
+
+    third_party_event_rules:
+        module: third_party_rules.RoomAccessRules
+        config:
+            # List of domains (server names) that can't be invited to rooms if the
+            # "restricted" rule is set. Defaults to an empty list.
+            domains_forbidden_when_restricted: []
+
+            # Identity server to use when checking the HS an email address belongs to
+            # using the /info endpoint. Required.
+            id_server: "vector.im"
+
+    Don't forget to consider if you can invite users from your own domain.
+    """
+
     def __init__(self, config, http_client):
         self.http_client = http_client
+
         self.id_server = config["id_server"]
+
         self.domains_forbidden_when_restricted = config.get(
             "domains_forbidden_when_restricted", [],
         )
@@ -43,34 +65,77 @@ class RoomAccessRules(object):
             raise ConfigError("No IS for event rules TchapEventRules")
 
     def on_create_room(self, requester, config, is_requester_admin):
+        """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+        Checks if a im.vector.room.access_rules event is being set during room creation.
+        If yes, make sure the event is correct. Otherwise, append an event with the
+        default rule to the initial state.
+        """
+        is_direct = config.get("is_direct")
+        rules_in_initial_state = False
+
+        # If there's a rules event in the initial state, check if it complies with the
+        # spec for im.vector.room.access_rules and fix it if not.
         for event in config.get("initial_state", []):
             if event["type"] == ACCESS_RULES_TYPE:
-                # If there's already a rules event in the initial state, check if it
-                # breaks the rules for "direct", and if not don't do anything else.
-                if (
-                    not config.get("is_direct")
-                    or event["content"]["rule"] != ACCESS_RULE_DIRECT
-                ):
-                    return
-
-        # Append an access rules event to be sent once every other event in initial_state
-        # has been sent. If "is_direct" exists and is set to True, the rule needs to be
-        # "direct", and "restricted" otherwise.
-        if config.get("is_direct"):
-            default_rule = ACCESS_RULE_DIRECT
-        else:
-            default_rule = ACCESS_RULE_RESTRICTED
+                rules_in_initial_state = True
+
+                rule = event["content"].get("rule")
+
+                # Make sure the event has a valid content.
+                if rule is None:
+                    event["content"] = {
+                        "rule": self._on_create_room_default_rule(is_direct)
+                    }
+
+                # Make sure the rule name is valid.
+                if not self._is_rule_name_valid(rule):
+                    event["content"]["rule"] = self._on_create_room_default_rule(
+                        is_direct,
+                    )
+
+                # Make sure the rule is "direct" if the room is a direct chat.
+                if is_direct and rule != ACCESS_RULE_DIRECT:
+                    event["content"]["rule"] = ACCESS_RULE_DIRECT
+
+                # Make sure the rule is not "direct" if the room isn't a direct chat.
+                if rule == ACCESS_RULE_DIRECT and not is_direct:
+                    event["content"]["rule"] = ACCESS_RULE_RESTRICTED
+
+        # If there's no rules event in the initial state, create one with the default
+        # setting.
+        if not rules_in_initial_state:
+            config["initial_state"].append({
+                "type": ACCESS_RULES_TYPE,
+                "state_key": "",
+                "content": {
+                    "rule": self._on_create_room_default_rule(is_direct),
+                }
+            })
 
-        config["initial_state"].append({
-            "type": ACCESS_RULES_TYPE,
-            "state_key": "",
-            "content": {
-                "rule": default_rule,
-            }
-        })
+    @staticmethod
+    def _on_create_room_default_rule(is_direct):
+        """Returns the default rule to set.
+
+        Args:
+            is_direct (bool): Is the room created with "is_direct" set to True.
+
+        Returns:
+            str, the name of the rule tu use as the default.
+        """
+        if is_direct:
+            return ACCESS_RULE_DIRECT
+        else:
+            return ACCESS_RULE_RESTRICTED
 
     @defer.inlineCallbacks
     def check_threepid_can_be_invited(self, medium, address, state_events):
+        """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+        Check if a threepid can be invited to the room via a 3PID invite given the current
+        rules and the threepid's address, by retrieving the HS it's mapped to from the
+        configured identity server, and checking if we can invite users from it.
+        """
         rule = self._get_rule_from_state(state_events)
 
         if medium != "email":
@@ -105,6 +170,11 @@ class RoomAccessRules(object):
         defer.returnValue(True)
 
     def check_event_allowed(self, event, state_events):
+        """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+        Checks the event's type and the current rule and calls the right function to
+        determine whether the event can be allowed.
+        """
         # Special-case the access rules event.
         if event.type == ACCESS_RULES_TYPE:
             return self._on_rules_change(event, state_events)
@@ -125,14 +195,20 @@ class RoomAccessRules(object):
         return ret
 
     def _on_rules_change(self, event, state_events):
+        """Implement the checks and behaviour specified on allowing or forbidding a new
+        im.vector.room.access_rules event.
+
+        Args:
+            event (synapse.events.EventBase): The event to check.
+            state_events (dict[tuple[event type, state key], EventBase]): The state of the
+                room before the event was sent.
+        Returns:
+            bool, True if the event can be allowed, False otherwise.
+        """
         new_rule = event.content.get("rule")
 
         # Check for invalid values.
-        if (
-            new_rule != ACCESS_RULE_DIRECT
-            and new_rule != ACCESS_RULE_RESTRICTED
-            and new_rule != ACCESS_RULE_UNRESTRICTED
-        ):
+        if not self._is_rule_name_valid(new_rule):
             return False
 
         # Make sure we don't apply "direct" if the room has more than two members.
@@ -161,16 +237,37 @@ class RoomAccessRules(object):
         return False
 
     def _apply_restricted(self, event):
+        """Implements the checks and behaviour specified for the "restricted" rule.
+
+        Args:
+            event (synapse.events.EventBase): The event to check.
+        Returns:
+            bool, True if the event can be allowed, False otherwise.
+        """
         # "restricted" currently means that users can only invite users if their server is
         # included in a limited list of domains.
         invitee_domain = DomainRuleChecker._get_domain_from_id(event.state_key)
         return invitee_domain not in self.domains_forbidden_when_restricted
 
     def _apply_unrestricted(self):
+        """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+        Returns:
+            bool, True if the event can be allowed, False otherwise.
+        """
         # "unrestricted" currently means that every event is allowed.
         return True
 
     def _apply_direct(self, event, state_events):
+        """Implements the checks and behaviour specified for the "direct" rule.
+
+        Args:
+            event (synapse.events.EventBase): The event to check.
+            state_events (dict[tuple[event type, state key], EventBase]): The state of the
+                room before the event was sent.
+        Returns:
+            bool, True if the event can be allowed, False otherwise.
+        """
         # "direct" currently means that no member is allowed apart from the two initial
         # members the room was created for (i.e. the room's creator and their first
         # invitee).
@@ -235,6 +332,14 @@ class RoomAccessRules(object):
 
     @staticmethod
     def _get_rule_from_state(state_events):
+        """Extract the rule to be applied from the given set of state events.
+
+        Args:
+            state_events (dict[tuple[event type, state key], EventBase]): The set of state
+                events
+        Returns:
+            str, the name of the rule (either "direct", "restricted" or "unrestricted")
+        """
         access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
         if access_rules is None:
             rule = ACCESS_RULE_RESTRICTED
@@ -244,5 +349,27 @@ class RoomAccessRules(object):
 
     @staticmethod
     def _is_invite_from_threepid(invite, threepid_invite):
+        """Checks whether the given invite follows the given 3PID invite.
+
+        Args:
+             invite (EventBase): The m.room.member event with "invite" membership.
+             threepid_invite (EventBase): The m.room.third_party_invite event.
+        """
         token = invite.content.get("third_party_signed", {}).get("token", "")
         return token == threepid_invite.state_key
+
+    @staticmethod
+    def _is_rule_name_valid(rule):
+        """Returns whether the given rule name is within the allowed values ("direct",
+        "restricted" or "unrestricted").
+
+        Args:
+            rule (str): The name of the rule.
+        Returns:
+            bool, True if the name is valid, False otherwise.
+        """
+        return (
+            rule == ACCESS_RULE_DIRECT
+            or rule == ACCESS_RULE_RESTRICTED
+            or rule == ACCESS_RULE_UNRESTRICTED
+        )