summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11360.misc1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/appservice/__init__.py101
-rw-r--r--synapse/appservice/api.py48
-rw-r--r--synapse/appservice/scheduler.py74
-rw-r--r--synapse/config/appservice.py3
-rw-r--r--tests/appservice/test_appservice.py11
7 files changed, 148 insertions, 93 deletions
diff --git a/changelog.d/11360.misc b/changelog.d/11360.misc
new file mode 100644
index 0000000000..43e25720c5
--- /dev/null
+++ b/changelog.d/11360.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.appservice`.
diff --git a/mypy.ini b/mypy.ini
index 9aeeca2bb2..4551302c82 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -143,6 +143,9 @@ disallow_untyped_defs = True
 [mypy-synapse.app.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.appservice.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.config._base]
 disallow_untyped_defs = True
 
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f9d3bd337d..8c9ff93b2c 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Match, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern
+
+import attr
+from netaddr import IPSet
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -33,6 +37,13 @@ class ApplicationServiceState(Enum):
     UP = "up"
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class Namespace:
+    exclusive: bool
+    group_id: Optional[str]
+    regex: Pattern[str]
+
+
 class ApplicationService:
     """Defines an application service. This definition is mostly what is
     provided to the /register AS API.
@@ -50,17 +61,17 @@ class ApplicationService:
 
     def __init__(
         self,
-        token,
-        hostname,
-        id,
-        sender,
-        url=None,
-        namespaces=None,
-        hs_token=None,
-        protocols=None,
-        rate_limited=True,
-        ip_range_whitelist=None,
-        supports_ephemeral=False,
+        token: str,
+        hostname: str,
+        id: str,
+        sender: str,
+        url: Optional[str] = None,
+        namespaces: Optional[JsonDict] = None,
+        hs_token: Optional[str] = None,
+        protocols: Optional[Iterable[str]] = None,
+        rate_limited: bool = True,
+        ip_range_whitelist: Optional[IPSet] = None,
+        supports_ephemeral: bool = False,
     ):
         self.token = token
         self.url = (
@@ -85,27 +96,33 @@ class ApplicationService:
 
         self.rate_limited = rate_limited
 
-    def _check_namespaces(self, namespaces):
+    def _check_namespaces(
+        self, namespaces: Optional[JsonDict]
+    ) -> Dict[str, List[Namespace]]:
         # Sanity check that it is of the form:
         # {
         #   users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         #   rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
         # }
-        if not namespaces:
+        if namespaces is None:
             namespaces = {}
 
+        result: Dict[str, List[Namespace]] = {}
+
         for ns in ApplicationService.NS_LIST:
+            result[ns] = []
+
             if ns not in namespaces:
-                namespaces[ns] = []
                 continue
 
-            if type(namespaces[ns]) != list:
+            if not isinstance(namespaces[ns], list):
                 raise ValueError("Bad namespace value for '%s'" % ns)
             for regex_obj in namespaces[ns]:
                 if not isinstance(regex_obj, dict):
                     raise ValueError("Expected dict regex for ns '%s'" % ns)
-                if not isinstance(regex_obj.get("exclusive"), bool):
+                exclusive = regex_obj.get("exclusive")
+                if not isinstance(exclusive, bool):
                     raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
                 group_id = regex_obj.get("group_id")
                 if group_id:
@@ -126,22 +143,26 @@ class ApplicationService:
                         )
 
                 regex = regex_obj.get("regex")
-                if isinstance(regex, str):
-                    regex_obj["regex"] = re.compile(regex)  # Pre-compile regex
-                else:
+                if not isinstance(regex, str):
                     raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
-        return namespaces
 
-    def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
-        for regex_obj in self.namespaces[namespace_key]:
-            if regex_obj["regex"].match(test_string):
-                return regex_obj
+                # Pre-compile regex.
+                result[ns].append(Namespace(exclusive, group_id, re.compile(regex)))
+
+        return result
+
+    def _matches_regex(
+        self, namespace_key: str, test_string: str
+    ) -> Optional[Namespace]:
+        for namespace in self.namespaces[namespace_key]:
+            if namespace.regex.match(test_string):
+                return namespace
         return None
 
-    def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
-        regex_obj = self._matches_regex(test_string, ns_key)
-        if regex_obj:
-            return regex_obj["exclusive"]
+    def _is_exclusive(self, namespace_key: str, test_string: str) -> bool:
+        namespace = self._matches_regex(namespace_key, test_string)
+        if namespace:
+            return namespace.exclusive
         return False
 
     async def _matches_user(
@@ -260,15 +281,15 @@ class ApplicationService:
 
     def is_interested_in_user(self, user_id: str) -> bool:
         return (
-            bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
+            bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
             or user_id == self.sender
         )
 
     def is_interested_in_alias(self, alias: str) -> bool:
-        return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
+        return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
 
     def is_interested_in_room(self, room_id: str) -> bool:
-        return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
+        return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
 
     def is_exclusive_user(self, user_id: str) -> bool:
         return (
@@ -285,14 +306,14 @@ class ApplicationService:
     def is_exclusive_room(self, room_id: str) -> bool:
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exclusive_user_regexes(self):
+    def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """
         return [
-            regex_obj["regex"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if regex_obj["exclusive"]
+            namespace.regex
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.exclusive
         ]
 
     def get_groups_for_user(self, user_id: str) -> Iterable[str]:
@@ -305,15 +326,15 @@ class ApplicationService:
             An iterable that yields group_id strings.
         """
         return (
-            regex_obj["group_id"]
-            for regex_obj in self.namespaces[ApplicationService.NS_USERS]
-            if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
+            namespace.group_id
+            for namespace in self.namespaces[ApplicationService.NS_USERS]
+            if namespace.group_id and namespace.regex.match(user_id)
         )
 
     def is_rate_limited(self) -> bool:
         return self.rate_limited
 
-    def __str__(self):
+    def __str__(self) -> str:
         # copy dictionary and redact token fields so they don't get logged
         dict_copy = self.__dict__.copy()
         dict_copy["token"] = "<redacted>"
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b636417..def4424af0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, List, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
 
@@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000
 APP_SERVICE_PREFIX = "/_matrix/app/unstable"
 
 
-def _is_valid_3pe_metadata(info):
+def _is_valid_3pe_metadata(info: JsonDict) -> bool:
     if "instances" not in info:
         return False
     if not isinstance(info["instances"], list):
@@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
     return True
 
 
-def _is_valid_3pe_result(r, field):
+def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
     if not isinstance(r, dict):
         return False
 
@@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
-    async def query_user(self, service, user_id):
+    async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_user to %s threw exception %s", uri, ex)
         return False
 
-    async def query_alias(self, service, alias):
+    async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
         return False
 
-    async def query_3pe(self, service, kind, protocol, fields):
+    async def query_3pe(
+        self,
+        service: "ApplicationService",
+        kind: str,
+        protocol: str,
+        fields: Dict[bytes, List[bytes]],
+    ) -> List[JsonDict]:
         if kind == ThirdPartyEntityKind.USER:
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
@@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient):
         events: List[EventBase],
         ephemeral: List[JsonDict],
         txn_id: Optional[int] = None,
-    ):
+    ) -> bool:
         if service.url is None:
             return True
 
-        events = self._serialize(service, events)
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        serialized_events = self._serialize(service, events)
 
         if txn_id is None:
             logger.warning(
@@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         # Never send ephemeral events to appservices that do not support it
         if service.supports_ephemeral:
-            body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+            body = {
+                "events": serialized_events,
+                "de.sorunome.msc2409.ephemeral": ephemeral,
+            }
         else:
-            body = {"events": events}
+            body = {"events": serialized_events}
 
         try:
             await self.put_json(
@@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                     [event.get("event_id") for event in events],
                 )
             sent_transactions_counter.labels(service.id).inc()
-            sent_events_counter.labels(service.id).inc(len(events))
+            sent_events_counter.labels(service.id).inc(len(serialized_events))
             return True
         except CodeMessageException as e:
             logger.warning(
@@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
-    def _serialize(self, service, events):
+    def _serialize(
+        self, service: "ApplicationService", events: Iterable[EventBase]
+    ) -> List[JsonDict]:
         time_now = self.clock.time_msec()
         return [
             serialize_event(
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6a2ce99b55..185e3a5278 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,13 +48,19 @@ This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
 import logging
-from typing import List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
+from synapse.appservice.api import ApplicationServiceApi
 from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main import DataStore
 from synapse.types import JsonDict
+from synapse.util import Clock
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +78,7 @@ class ApplicationServiceScheduler:
     case is a simple array.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.as_api = hs.get_application_service_api()
@@ -80,7 +86,7 @@ class ApplicationServiceScheduler:
         self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
         self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
 
-    async def start(self):
+    async def start(self) -> None:
         logger.info("Starting appservice scheduler")
 
         # check for any DOWN ASes and start recoverers for them.
@@ -91,12 +97,14 @@ class ApplicationServiceScheduler:
         for service in services:
             self.txn_ctrl.start_recoverer(service)
 
-    def submit_event_for_as(self, service: ApplicationService, event: EventBase):
+    def submit_event_for_as(
+        self, service: ApplicationService, event: EventBase
+    ) -> None:
         self.queuer.enqueue_event(service, event)
 
     def submit_ephemeral_events_for_as(
         self, service: ApplicationService, events: List[JsonDict]
-    ):
+    ) -> None:
         self.queuer.enqueue_ephemeral(service, events)
 
 
@@ -108,16 +116,18 @@ class _ServiceQueuer:
     appservice at a given time.
     """
 
-    def __init__(self, txn_ctrl, clock):
-        self.queued_events = {}  # dict of {service_id: [events]}
-        self.queued_ephemeral = {}  # dict of {service_id: [events]}
+    def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
+        # dict of {service_id: [events]}
+        self.queued_events: Dict[str, List[EventBase]] = {}
+        # dict of {service_id: [events]}
+        self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
 
         # the appservices which currently have a transaction in flight
-        self.requests_in_flight = set()
+        self.requests_in_flight: Set[str] = set()
         self.txn_ctrl = txn_ctrl
         self.clock = clock
 
-    def _start_background_request(self, service):
+    def _start_background_request(self, service: ApplicationService) -> None:
         # start a sender for this appservice if we don't already have one
         if service.id in self.requests_in_flight:
             return
@@ -126,15 +136,17 @@ class _ServiceQueuer:
             "as-sender-%s" % (service.id,), self._send_request, service
         )
 
-    def enqueue_event(self, service: ApplicationService, event: EventBase):
+    def enqueue_event(self, service: ApplicationService, event: EventBase) -> None:
         self.queued_events.setdefault(service.id, []).append(event)
         self._start_background_request(service)
 
-    def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
+    def enqueue_ephemeral(
+        self, service: ApplicationService, events: List[JsonDict]
+    ) -> None:
         self.queued_ephemeral.setdefault(service.id, []).extend(events)
         self._start_background_request(service)
 
-    async def _send_request(self, service: ApplicationService):
+    async def _send_request(self, service: ApplicationService) -> None:
         # sanity-check: we shouldn't get here if this service already has a sender
         # running.
         assert service.id not in self.requests_in_flight
@@ -168,20 +180,15 @@ class _TransactionController:
     if a transaction fails.
 
     (Note we have only have one of these in the homeserver.)
-
-    Args:
-        clock (synapse.util.Clock):
-        store (synapse.storage.DataStore):
-        as_api (synapse.appservice.api.ApplicationServiceApi):
     """
 
-    def __init__(self, clock, store, as_api):
+    def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi):
         self.clock = clock
         self.store = store
         self.as_api = as_api
 
         # map from service id to recoverer instance
-        self.recoverers = {}
+        self.recoverers: Dict[str, "_Recoverer"] = {}
 
         # for UTs
         self.RECOVERER_CLASS = _Recoverer
@@ -191,7 +198,7 @@ class _TransactionController:
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: Optional[List[JsonDict]] = None,
-    ):
+    ) -> None:
         try:
             txn = await self.store.create_appservice_txn(
                 service=service, events=events, ephemeral=ephemeral or []
@@ -207,7 +214,7 @@ class _TransactionController:
             logger.exception("Error creating appservice transaction")
             run_in_background(self._on_txn_fail, service)
 
-    async def on_recovered(self, recoverer):
+    async def on_recovered(self, recoverer: "_Recoverer") -> None:
         logger.info(
             "Successfully recovered application service AS ID %s", recoverer.service.id
         )
@@ -217,18 +224,18 @@ class _TransactionController:
             recoverer.service, ApplicationServiceState.UP
         )
 
-    async def _on_txn_fail(self, service):
+    async def _on_txn_fail(self, service: ApplicationService) -> None:
         try:
             await self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
             self.start_recoverer(service)
         except Exception:
             logger.exception("Error starting AS recoverer")
 
-    def start_recoverer(self, service):
+    def start_recoverer(self, service: ApplicationService) -> None:
         """Start a Recoverer for the given service
 
         Args:
-            service (synapse.appservice.ApplicationService):
+            service:
         """
         logger.info("Starting recoverer for AS ID %s", service.id)
         assert service.id not in self.recoverers
@@ -257,7 +264,14 @@ class _Recoverer:
         callback (callable[_Recoverer]): called once the service recovers.
     """
 
-    def __init__(self, clock, store, as_api, service, callback):
+    def __init__(
+        self,
+        clock: Clock,
+        store: DataStore,
+        as_api: ApplicationServiceApi,
+        service: ApplicationService,
+        callback: Callable[["_Recoverer"], Awaitable[None]],
+    ):
         self.clock = clock
         self.store = store
         self.as_api = as_api
@@ -265,8 +279,8 @@ class _Recoverer:
         self.callback = callback
         self.backoff_counter = 1
 
-    def recover(self):
-        def _retry():
+    def recover(self) -> None:
+        def _retry() -> None:
             run_as_background_process(
                 "as-recoverer-%s" % (self.service.id,), self.retry
             )
@@ -275,13 +289,13 @@ class _Recoverer:
         logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
         self.clock.call_later(delay, _retry)
 
-    def _backoff(self):
+    def _backoff(self) -> None:
         # cap the backoff to be around 8.5min => (2^9) = 512 secs
         if self.backoff_counter < 9:
             self.backoff_counter += 1
         self.recover()
 
-    async def retry(self):
+    async def retry(self) -> None:
         logger.info("Starting retries on %s", self.service.id)
         try:
             while True:
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index e4bb7224a4..7fad2e0422 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -147,8 +147,7 @@ def _load_appservice(
     # protocols check
     protocols = as_info.get("protocols")
     if protocols:
-        # Because strings are lists in python
-        if isinstance(protocols, str) or not isinstance(protocols, list):
+        if not isinstance(protocols, list):
             raise KeyError("Optional 'protocols' must be a list if present.")
         for p in protocols:
             if not isinstance(p, str):
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index f386b5e128..ba2a2bfd64 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -16,13 +16,13 @@ from unittest.mock import Mock
 
 from twisted.internet import defer
 
-from synapse.appservice import ApplicationService
+from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
 
 
-def _regex(regex, exclusive=True):
-    return {"regex": re.compile(regex), "exclusive": exclusive}
+def _regex(regex: str, exclusive: bool = True) -> Namespace:
+    return Namespace(exclusive, None, re.compile(regex))
 
 
 class ApplicationServiceTestCase(unittest.TestCase):
@@ -33,11 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
             url="some_url",
             token="some_token",
             hostname="matrix.org",  # only used by get_groups_for_user
-            namespaces={
-                ApplicationService.NS_USERS: [],
-                ApplicationService.NS_ROOMS: [],
-                ApplicationService.NS_ALIASES: [],
-            },
         )
         self.event = Mock(
             type="m.something", room_id="!foo:bar", sender="@someone:somewhere"