summary refs log tree commit diff
path: root/synapse/appservice/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice/api.py')
-rw-r--r--synapse/appservice/api.py64
1 files changed, 36 insertions, 28 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 57174da021..bb6fa8299a 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,20 +13,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-
-from six.moves import urllib
+import urllib
+from typing import TYPE_CHECKING, Optional
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
 
+if TYPE_CHECKING:
+    from synapse.appservice import ApplicationService
+
 logger = logging.getLogger(__name__)
 
 sent_transactions_counter = Counter(
@@ -94,14 +95,12 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
-    @defer.inlineCallbacks
-    def query_user(self, service, user_id):
+    async def query_user(self, service, user_id):
         if service.url is None:
             return False
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
-        response = None
         try:
-            response = yield self.get_json(uri, {"access_token": service.hs_token})
+            response = await self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
                 return True
         except CodeMessageException as e:
@@ -112,14 +111,12 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_user to %s threw exception %s", uri, ex)
         return False
 
-    @defer.inlineCallbacks
-    def query_alias(self, service, alias):
+    async def query_alias(self, service, alias):
         if service.url is None:
             return False
         uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
-        response = None
         try:
-            response = yield self.get_json(uri, {"access_token": service.hs_token})
+            response = await self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
                 return True
         except CodeMessageException as e:
@@ -130,8 +127,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
         return False
 
-    @defer.inlineCallbacks
-    def query_3pe(self, service, kind, protocol, fields):
+    async def query_3pe(self, service, kind, protocol, fields):
         if kind == ThirdPartyEntityKind.USER:
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
@@ -148,7 +144,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             urllib.parse.quote(protocol),
         )
         try:
-            response = yield self.get_json(uri, fields)
+            response = await self.get_json(uri, fields)
             if not isinstance(response, list):
                 logger.warning(
                     "query_3pe to %s returned an invalid response %r", uri, response
@@ -169,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_3pe to %s threw exception %s", uri, ex)
             return []
 
-    def get_3pe_protocol(self, service, protocol):
+    async def get_3pe_protocol(
+        self, service: "ApplicationService", protocol: str
+    ) -> Optional[JsonDict]:
         if service.url is None:
             return {}
 
-        @defer.inlineCallbacks
-        def _get():
+        async def _get() -> Optional[JsonDict]:
             uri = "%s%s/thirdparty/protocol/%s" % (
                 service.url,
                 APP_SERVICE_PREFIX,
                 urllib.parse.quote(protocol),
             )
             try:
-                info = yield self.get_json(uri, {})
+                info = await self.get_json(uri, {})
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
@@ -202,14 +199,13 @@ class ApplicationServiceApi(SimpleHttpClient):
                 return None
 
         key = (service.id, protocol)
-        return self.protocol_meta_cache.wrap(key, _get)
+        return await self.protocol_meta_cache.wrap(key, _get)
 
-    @defer.inlineCallbacks
-    def push_bulk(self, service, events, txn_id=None):
+    async def push_bulk(self, service, events, txn_id=None):
         if service.url is None:
             return True
 
-        events = self._serialize(events)
+        events = self._serialize(service, events)
 
         if txn_id is None:
             logger.warning(
@@ -220,7 +216,7 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
         try:
-            yield self.put_json(
+            await self.put_json(
                 uri=uri,
                 json_body={"events": events},
                 args={"access_token": service.hs_token},
@@ -235,6 +231,18 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
-    def _serialize(self, events):
+    def _serialize(self, service, events):
         time_now = self.clock.time_msec()
-        return [serialize_event(e, time_now, as_client_event=True) for e in events]
+        return [
+            serialize_event(
+                e,
+                time_now,
+                as_client_event=True,
+                is_invite=(
+                    e.type == EventTypes.Member
+                    and e.membership == "invite"
+                    and service.is_interested_in_user(e.state_key)
+                ),
+            )
+            for e in events
+        ]