diff options
Diffstat (limited to 'synapse/appservice/__init__.py')
-rw-r--r-- | synapse/appservice/__init__.py | 101 |
1 files changed, 61 insertions, 40 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f9d3bd337d..8c9ff93b2c 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import re from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Match, Optional +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern + +import attr +from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase @@ -33,6 +37,13 @@ class ApplicationServiceState(Enum): UP = "up" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Namespace: + exclusive: bool + group_id: Optional[str] + regex: Pattern[str] + + class ApplicationService: """Defines an application service. This definition is mostly what is provided to the /register AS API. @@ -50,17 +61,17 @@ class ApplicationService: def __init__( self, - token, - hostname, - id, - sender, - url=None, - namespaces=None, - hs_token=None, - protocols=None, - rate_limited=True, - ip_range_whitelist=None, - supports_ephemeral=False, + token: str, + hostname: str, + id: str, + sender: str, + url: Optional[str] = None, + namespaces: Optional[JsonDict] = None, + hs_token: Optional[str] = None, + protocols: Optional[Iterable[str]] = None, + rate_limited: bool = True, + ip_range_whitelist: Optional[IPSet] = None, + supports_ephemeral: bool = False, ): self.token = token self.url = ( @@ -85,27 +96,33 @@ class ApplicationService: self.rate_limited = rate_limited - def _check_namespaces(self, namespaces): + def _check_namespaces( + self, namespaces: Optional[JsonDict] + ) -> Dict[str, List[Namespace]]: # Sanity check that it is of the form: # { # users: [ {regex: "[A-z]+.*", exclusive: true}, ...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # } - if not namespaces: + if namespaces is None: namespaces = {} + result: Dict[str, List[Namespace]] = {} + for ns in ApplicationService.NS_LIST: + result[ns] = [] + if ns not in namespaces: - namespaces[ns] = [] continue - if type(namespaces[ns]) != list: + if not isinstance(namespaces[ns], list): raise ValueError("Bad namespace value for '%s'" % ns) for regex_obj in namespaces[ns]: if not isinstance(regex_obj, dict): raise ValueError("Expected dict regex for ns '%s'" % ns) - if not isinstance(regex_obj.get("exclusive"), bool): + exclusive = regex_obj.get("exclusive") + if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) group_id = regex_obj.get("group_id") if group_id: @@ -126,22 +143,26 @@ class ApplicationService: ) regex = regex_obj.get("regex") - if isinstance(regex, str): - regex_obj["regex"] = re.compile(regex) # Pre-compile regex - else: + if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) - return namespaces - def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]: - for regex_obj in self.namespaces[namespace_key]: - if regex_obj["regex"].match(test_string): - return regex_obj + # Pre-compile regex. + result[ns].append(Namespace(exclusive, group_id, re.compile(regex))) + + return result + + def _matches_regex( + self, namespace_key: str, test_string: str + ) -> Optional[Namespace]: + for namespace in self.namespaces[namespace_key]: + if namespace.regex.match(test_string): + return namespace return None - def _is_exclusive(self, ns_key: str, test_string: str) -> bool: - regex_obj = self._matches_regex(test_string, ns_key) - if regex_obj: - return regex_obj["exclusive"] + def _is_exclusive(self, namespace_key: str, test_string: str) -> bool: + namespace = self._matches_regex(namespace_key, test_string) + if namespace: + return namespace.exclusive return False async def _matches_user( @@ -260,15 +281,15 @@ class ApplicationService: def is_interested_in_user(self, user_id: str) -> bool: return ( - bool(self._matches_regex(user_id, ApplicationService.NS_USERS)) + bool(self._matches_regex(ApplicationService.NS_USERS, user_id)) or user_id == self.sender ) def is_interested_in_alias(self, alias: str) -> bool: - return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) + return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias)) def is_interested_in_room(self, room_id: str) -> bool: - return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) + return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id)) def is_exclusive_user(self, user_id: str) -> bool: return ( @@ -285,14 +306,14 @@ class ApplicationService: def is_exclusive_room(self, room_id: str) -> bool: return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exclusive_user_regexes(self): + def get_exclusive_user_regexes(self) -> List[Pattern[str]]: """Get the list of regexes used to determine if a user is exclusively registered by the AS """ return [ - regex_obj["regex"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if regex_obj["exclusive"] + namespace.regex + for namespace in self.namespaces[ApplicationService.NS_USERS] + if namespace.exclusive ] def get_groups_for_user(self, user_id: str) -> Iterable[str]: @@ -305,15 +326,15 @@ class ApplicationService: An iterable that yields group_id strings. """ return ( - regex_obj["group_id"] - for regex_obj in self.namespaces[ApplicationService.NS_USERS] - if "group_id" in regex_obj and regex_obj["regex"].match(user_id) + namespace.group_id + for namespace in self.namespaces[ApplicationService.NS_USERS] + if namespace.group_id and namespace.regex.match(user_id) ) def is_rate_limited(self) -> bool: return self.rate_limited - def __str__(self): + def __str__(self) -> str: # copy dictionary and redact token fields so they don't get logged dict_copy = self.__dict__.copy() dict_copy["token"] = "<redacted>" |