summary refs log tree commit diff
path: root/synapse/appservice/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice/api.py')
-rw-r--r--synapse/appservice/api.py47
1 files changed, 26 insertions, 21 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 57174da021..e72a0b9ac0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -13,14 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-
-from six.moves import urllib
+import urllib
 
 from prometheus_client import Counter
 
 from twisted.internet import defer
 
-from synapse.api.constants import ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.events.utils import serialize_event
 from synapse.http.client import SimpleHttpClient
@@ -94,14 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient):
             hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
         )
 
-    @defer.inlineCallbacks
-    def query_user(self, service, user_id):
+    async def query_user(self, service, user_id):
         if service.url is None:
             return False
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
-        response = None
         try:
-            response = yield self.get_json(uri, {"access_token": service.hs_token})
+            response = await self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
                 return True
         except CodeMessageException as e:
@@ -112,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_user to %s threw exception %s", uri, ex)
         return False
 
-    @defer.inlineCallbacks
-    def query_alias(self, service, alias):
+    async def query_alias(self, service, alias):
         if service.url is None:
             return False
         uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
-        response = None
         try:
-            response = yield self.get_json(uri, {"access_token": service.hs_token})
+            response = await self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
                 return True
         except CodeMessageException as e:
@@ -130,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
         return False
 
-    @defer.inlineCallbacks
-    def query_3pe(self, service, kind, protocol, fields):
+    async def query_3pe(self, service, kind, protocol, fields):
         if kind == ThirdPartyEntityKind.USER:
             required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
@@ -148,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient):
             urllib.parse.quote(protocol),
         )
         try:
-            response = yield self.get_json(uri, fields)
+            response = await self.get_json(uri, fields)
             if not isinstance(response, list):
                 logger.warning(
                     "query_3pe to %s returned an invalid response %r", uri, response
@@ -181,7 +175,7 @@ class ApplicationServiceApi(SimpleHttpClient):
                 urllib.parse.quote(protocol),
             )
             try:
-                info = yield self.get_json(uri, {})
+                info = yield defer.ensureDeferred(self.get_json(uri, {}))
 
                 if not _is_valid_3pe_metadata(info):
                     logger.warning(
@@ -204,12 +198,11 @@ class ApplicationServiceApi(SimpleHttpClient):
         key = (service.id, protocol)
         return self.protocol_meta_cache.wrap(key, _get)
 
-    @defer.inlineCallbacks
-    def push_bulk(self, service, events, txn_id=None):
+    async def push_bulk(self, service, events, txn_id=None):
         if service.url is None:
             return True
 
-        events = self._serialize(events)
+        events = self._serialize(service, events)
 
         if txn_id is None:
             logger.warning(
@@ -220,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
         try:
-            yield self.put_json(
+            await self.put_json(
                 uri=uri,
                 json_body={"events": events},
                 args={"access_token": service.hs_token},
@@ -235,6 +228,18 @@ class ApplicationServiceApi(SimpleHttpClient):
         failed_transactions_counter.labels(service.id).inc()
         return False
 
-    def _serialize(self, events):
+    def _serialize(self, service, events):
         time_now = self.clock.time_msec()
-        return [serialize_event(e, time_now, as_client_event=True) for e in events]
+        return [
+            serialize_event(
+                e,
+                time_now,
+                as_client_event=True,
+                is_invite=(
+                    e.type == EventTypes.Member
+                    and e.membership == "invite"
+                    and service.is_interested_in_user(e.state_key)
+                ),
+            )
+            for e in events
+        ]