summary refs log tree commit diff
path: root/synapse/appservice/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice/api.py')
-rw-r--r--synapse/appservice/api.py48
1 files changed, 35 insertions, 13 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index f51b636417..def4424af0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import urllib
-from typing import TYPE_CHECKING, List, Optional, Tuple
+import urllib.parse
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 from prometheus_client import Counter
 
@@ -53,7 +53,7 @@ HOUR_IN_MS = 60 * 60 * 1000
 APP_SERVICE_PREFIX = "/_matrix/app/unstable"
 
 
-def _is_valid_3pe_metadata(info):
+def _is_valid_3pe_metadata(info: JsonDict) -> bool:
     if "instances" not in info:
         return False
     if not isinstance(info["instances"], list):
@@ -61,7 +61,7 @@ def _is_valid_3pe_metadata(info):
     return True
 
 
-def _is_valid_3pe_result(r, field):
+def _is_valid_3pe_result(r: JsonDict, field: str) -> bool:
     if not isinstance(r, dict):
         return False
 
@@ -93,9 +93,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
-    async def query_user(self, service, user_id):
+    async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -109,9 +113,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_user to %s threw exception %s", uri, ex)
         return False
 
-    async def query_alias(self, service, alias):
+    async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
         if service.url is None:
             return False
+
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
         uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
         try:
             response = await self.get_json(uri, {"access_token": service.hs_token})
@@ -125,7 +133,13 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
         return False
 
-    async def query_3pe(self, service, kind, protocol, fields):
+    async def query_3pe(
+        self,
+        service: "ApplicationService",
+        kind: str,
+        protocol: str,
+        fields: Dict[bytes, List[bytes]],
+    ) -> List[JsonDict]:
         if kind == ThirdPartyEntityKind.USER:
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
@@ -205,11 +219,14 @@ class ApplicationServiceApi(SimpleHttpClient):
         events: List[EventBase],
         ephemeral: List[JsonDict],
         txn_id: Optional[int] = None,
-    ):
+    ) -> bool:
         if service.url is None:
             return True
 
-        events = self._serialize(service, events)
+        # This is required by the configuration.
+        assert service.hs_token is not None
+
+        serialized_events = self._serialize(service, events)
 
         if txn_id is None:
             logger.warning(
@@ -221,9 +238,12 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         # Never send ephemeral events to appservices that do not support it
         if service.supports_ephemeral:
-            body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
+            body = {
+                "events": serialized_events,
+                "de.sorunome.msc2409.ephemeral": ephemeral,
+            }
         else:
-            body = {"events": events}
+            body = {"events": serialized_events}
 
         try:
             await self.put_json(
@@ -238,7 +258,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                     [event.get("event_id") for event in events],
                 )
             sent_transactions_counter.labels(service.id).inc()
-            sent_events_counter.labels(service.id).inc(len(events))
+            sent_events_counter.labels(service.id).inc(len(serialized_events))
             return True
         except CodeMessageException as e:
             logger.warning(
@@ -260,7 +280,9 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
-    def _serialize(self, service, events):
+    def _serialize(
+        self, service: "ApplicationService", events: Iterable[EventBase]
+    ) -> List[JsonDict]:
         time_now = self.clock.time_msec()
         return [
             serialize_event(