diff options
-rw-r--r-- | synapse/third_party_rules/access_rules.py | 181 |
1 files changed, 154 insertions, 27 deletions
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py index 1f03138752..a8b3ed9458 100644 --- a/synapse/third_party_rules/access_rules.py +++ b/synapse/third_party_rules/access_rules.py @@ -28,9 +28,31 @@ ACCESS_RULE_DIRECT = "direct" class RoomAccessRules(object): + """Implementation of the ThirdPartyEventRules module API that allows federation admins + to define custom rules for specific events and actions. + Implements the custom behaviour for the "im.vector.room.access_rules" state event. + + Takes a config in the format: + + third_party_event_rules: + module: third_party_rules.RoomAccessRules + config: + # List of domains (server names) that can't be invited to rooms if the + # "restricted" rule is set. Defaults to an empty list. + domains_forbidden_when_restricted: [] + + # Identity server to use when checking the HS an email address belongs to + # using the /info endpoint. Required. + id_server: "vector.im" + + Don't forget to consider if you can invite users from your own domain. + """ + def __init__(self, config, http_client): self.http_client = http_client + self.id_server = config["id_server"] + self.domains_forbidden_when_restricted = config.get( "domains_forbidden_when_restricted", [], ) @@ -43,34 +65,77 @@ class RoomAccessRules(object): raise ConfigError("No IS for event rules TchapEventRules") def on_create_room(self, requester, config, is_requester_admin): + """Implements synapse.events.ThirdPartyEventRules.on_create_room + + Checks if a im.vector.room.access_rules event is being set during room creation. + If yes, make sure the event is correct. Otherwise, append an event with the + default rule to the initial state. + """ + is_direct = config.get("is_direct") + rules_in_initial_state = False + + # If there's a rules event in the initial state, check if it complies with the + # spec for im.vector.room.access_rules and fix it if not. for event in config.get("initial_state", []): if event["type"] == ACCESS_RULES_TYPE: - # If there's already a rules event in the initial state, check if it - # breaks the rules for "direct", and if not don't do anything else. - if ( - not config.get("is_direct") - or event["content"]["rule"] != ACCESS_RULE_DIRECT - ): - return - - # Append an access rules event to be sent once every other event in initial_state - # has been sent. If "is_direct" exists and is set to True, the rule needs to be - # "direct", and "restricted" otherwise. - if config.get("is_direct"): - default_rule = ACCESS_RULE_DIRECT - else: - default_rule = ACCESS_RULE_RESTRICTED + rules_in_initial_state = True + + rule = event["content"].get("rule") + + # Make sure the event has a valid content. + if rule is None: + event["content"] = { + "rule": self._on_create_room_default_rule(is_direct) + } + + # Make sure the rule name is valid. + if not self._is_rule_name_valid(rule): + event["content"]["rule"] = self._on_create_room_default_rule( + is_direct, + ) + + # Make sure the rule is "direct" if the room is a direct chat. + if is_direct and rule != ACCESS_RULE_DIRECT: + event["content"]["rule"] = ACCESS_RULE_DIRECT + + # Make sure the rule is not "direct" if the room isn't a direct chat. + if rule == ACCESS_RULE_DIRECT and not is_direct: + event["content"]["rule"] = ACCESS_RULE_RESTRICTED + + # If there's no rules event in the initial state, create one with the default + # setting. + if not rules_in_initial_state: + config["initial_state"].append({ + "type": ACCESS_RULES_TYPE, + "state_key": "", + "content": { + "rule": self._on_create_room_default_rule(is_direct), + } + }) - config["initial_state"].append({ - "type": ACCESS_RULES_TYPE, - "state_key": "", - "content": { - "rule": default_rule, - } - }) + @staticmethod + def _on_create_room_default_rule(is_direct): + """Returns the default rule to set. + + Args: + is_direct (bool): Is the room created with "is_direct" set to True. + + Returns: + str, the name of the rule tu use as the default. + """ + if is_direct: + return ACCESS_RULE_DIRECT + else: + return ACCESS_RULE_RESTRICTED @defer.inlineCallbacks def check_threepid_can_be_invited(self, medium, address, state_events): + """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited + + Check if a threepid can be invited to the room via a 3PID invite given the current + rules and the threepid's address, by retrieving the HS it's mapped to from the + configured identity server, and checking if we can invite users from it. + """ rule = self._get_rule_from_state(state_events) if medium != "email": @@ -105,6 +170,11 @@ class RoomAccessRules(object): defer.returnValue(True) def check_event_allowed(self, event, state_events): + """Implements synapse.events.ThirdPartyEventRules.check_event_allowed + + Checks the event's type and the current rule and calls the right function to + determine whether the event can be allowed. + """ # Special-case the access rules event. if event.type == ACCESS_RULES_TYPE: return self._on_rules_change(event, state_events) @@ -125,14 +195,20 @@ class RoomAccessRules(object): return ret def _on_rules_change(self, event, state_events): + """Implement the checks and behaviour specified on allowing or forbidding a new + im.vector.room.access_rules event. + + Args: + event (synapse.events.EventBase): The event to check. + state_events (dict[tuple[event type, state key], EventBase]): The state of the + room before the event was sent. + Returns: + bool, True if the event can be allowed, False otherwise. + """ new_rule = event.content.get("rule") # Check for invalid values. - if ( - new_rule != ACCESS_RULE_DIRECT - and new_rule != ACCESS_RULE_RESTRICTED - and new_rule != ACCESS_RULE_UNRESTRICTED - ): + if not self._is_rule_name_valid(new_rule): return False # Make sure we don't apply "direct" if the room has more than two members. @@ -161,16 +237,37 @@ class RoomAccessRules(object): return False def _apply_restricted(self, event): + """Implements the checks and behaviour specified for the "restricted" rule. + + Args: + event (synapse.events.EventBase): The event to check. + Returns: + bool, True if the event can be allowed, False otherwise. + """ # "restricted" currently means that users can only invite users if their server is # included in a limited list of domains. invitee_domain = DomainRuleChecker._get_domain_from_id(event.state_key) return invitee_domain not in self.domains_forbidden_when_restricted def _apply_unrestricted(self): + """Implements the checks and behaviour specified for the "unrestricted" rule. + + Returns: + bool, True if the event can be allowed, False otherwise. + """ # "unrestricted" currently means that every event is allowed. return True def _apply_direct(self, event, state_events): + """Implements the checks and behaviour specified for the "direct" rule. + + Args: + event (synapse.events.EventBase): The event to check. + state_events (dict[tuple[event type, state key], EventBase]): The state of the + room before the event was sent. + Returns: + bool, True if the event can be allowed, False otherwise. + """ # "direct" currently means that no member is allowed apart from the two initial # members the room was created for (i.e. the room's creator and their first # invitee). @@ -235,6 +332,14 @@ class RoomAccessRules(object): @staticmethod def _get_rule_from_state(state_events): + """Extract the rule to be applied from the given set of state events. + + Args: + state_events (dict[tuple[event type, state key], EventBase]): The set of state + events + Returns: + str, the name of the rule (either "direct", "restricted" or "unrestricted") + """ access_rules = state_events.get((ACCESS_RULES_TYPE, "")) if access_rules is None: rule = ACCESS_RULE_RESTRICTED @@ -244,5 +349,27 @@ class RoomAccessRules(object): @staticmethod def _is_invite_from_threepid(invite, threepid_invite): + """Checks whether the given invite follows the given 3PID invite. + + Args: + invite (EventBase): The m.room.member event with "invite" membership. + threepid_invite (EventBase): The m.room.third_party_invite event. + """ token = invite.content.get("third_party_signed", {}).get("token", "") return token == threepid_invite.state_key + + @staticmethod + def _is_rule_name_valid(rule): + """Returns whether the given rule name is within the allowed values ("direct", + "restricted" or "unrestricted"). + + Args: + rule (str): The name of the rule. + Returns: + bool, True if the name is valid, False otherwise. + """ + return ( + rule == ACCESS_RULE_DIRECT + or rule == ACCESS_RULE_RESTRICTED + or rule == ACCESS_RULE_UNRESTRICTED + ) |