summary refs log tree commit diff
path: root/synapse/appservice/__init__.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-12-14 12:02:46 -0500
committerGitHub <noreply@github.com>2021-12-14 17:02:46 +0000
commit2519beaad25d2c824129c26eab10962773e643a6 (patch)
treee4a1eb70f5d960c7edf2486a03b27bc958edde55 /synapse/appservice/__init__.py
parentMerge branch 'master' into develop (diff)
downloadsynapse-2519beaad25d2c824129c26eab10962773e643a6.tar.xz
Add missing type hints to `synapse.appservice` (#11360)
Diffstat (limited to 'synapse/appservice/__init__.py')
-rw-r--r--synapse/appservice/__init__.py101
1 files changed, 61 insertions, 40 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d..8c9ff93b2c 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Match, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
+
+import attr
+from netaddr import IPSet
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
     UP = "up"
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Namespace:
+    exclusive: bool
+    group_id: Optional[str]
+    regex: Pattern[str]
+
+
 class ApplicationService:
     """Defines an application service. This definition is mostly what is
     provided to the /register AS API.
@@ -50,17 +61,17 @@ class ApplicationService:
 
     def __init__(
         self,
-        token,
-        hostname,
-        id,
-        sender,
-        url=None,
-        namespaces=None,
-        hs_token=None,
-        protocols=None,
-        rate_limited=True,
-        ip_range_whitelist=None,
-        supports_ephemeral=False,
+        token: str,
+        hostname: str,
+        id: str,
+        sender: str,
+        url: Optional[str] = None,
+        namespaces: Optional[JsonDict] = None,
+        hs_token: Optional[str] = None,
+        protocols: Optional[Iterable[str]] = None,
+        rate_limited: bool = True,
+        ip_range_whitelist: Optional[IPSet] = None,
+        supports_ephemeral: bool = False,
     ):
         self.token = token
         self.url = (
@@ -85,27 +96,33 @@ class ApplicationService:
 
         self.rate_limited = rate_limited
 
-    def _check_namespaces(self, namespaces):
+    def _check_namespaces(
+        self, namespaces: Optional[JsonDict]
+    ) -> Dict[str, List[Namespace]]:
         # Sanity check that it is of the form:
         # {
         #   users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         # }
-        if not namespaces:
+        if namespaces is None:
             namespaces = {}
 
+        result: Dict[str, List[Namespace]] = {}
+
         for ns in ApplicationService.NS_LIST:
+            result[ns] = []
+
             if ns not in namespaces:
-                namespaces[ns] = []
                 continue
 
-            if type(namespaces[ns]) != list:
+            if not isinstance(namespaces[ns], list):
                 raise ValueError("Bad namespace value for '%s'" % ns)
             for regex_obj in namespaces[ns]:
                 if not isinstance(regex_obj, dict):
                     raise ValueError("Expected dict regex for ns '%s'" % ns)
-                if not isinstance(regex_obj.get("exclusive"), bool):
+                exclusive = regex_obj.get("exclusive")
+                if not isinstance(exclusive, bool):
                     raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
                 group_id = regex_obj.get("group_id")
                 if group_id:
@@ -126,22 +143,26 @@ class ApplicationService:
                         )
 
                 regex = regex_obj.get("regex")
-                if isinstance(regex, str):
-                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
-                else:
+                if not isinstance(regex, str):
                     raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
-        return namespaces
 
-    def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
-        for regex_obj in self.namespaces[namespace_key]:
-            if regex_obj["regex"].match(test_string):
-                return regex_obj
+                # Pre-compile regex.
+                result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+
+        return result
+
+    def _matches_regex(
+        self, namespace_key: str, test_string: str
+    ) -> Optional[Namespace]:
+        for namespace in self.namespaces[namespace_key]:
+            if namespace.regex.match(test_string):
+                return namespace
         return None
 
-    def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
-        regex_obj = self._matches_regex(test_string, ns_key)
-        if regex_obj:
-            return regex_obj["exclusive"]
+    def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
+        namespace = self._matches_regex(namespace_key, test_string)
+        if namespace:
+            return namespace.exclusive
         return False
 
     async def _matches_user(
@@ -260,15 +281,15 @@ class ApplicationService:
 
     def is_interested_in_user(self, user_id: str) -> bool:
         return (
-            bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
+            bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
             or user_id == self.sender
         )
 
     def is_interested_in_alias(self, alias: str) -> bool:
-        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
+        return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
 
     def is_interested_in_room(self, room_id: str) -> bool:
-        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
+        return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
 
     def is_exclusive_user(self, user_id: str) -> bool:
         return (
@@ -285,14 +306,14 @@ class ApplicationService:
     def is_exclusive_room(self, room_id: str) -> bool:
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exclusive_user_regexes(self):
+    def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """
         return [
-            regex_obj["regex"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if regex_obj["exclusive"]
+            namespace.regex
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.exclusive
         ]
 
     def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ class ApplicationService:
             An iterable that yields group_id strings.
         """
         return (
-            regex_obj["group_id"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+            namespace.group_id
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.group_id and namespace.regex.match(user_id)
         )
 
     def is_rate_limited(self) -> bool:
         return self.rate_limited
 
-    def __str__(self):
+    def __str__(self) -> str:
         # copy dictionary and redact token fields so they don't get logged
         dict_copy = self.__dict__.copy()
         dict_copy["token"] = "<redacted>"