summary refs log tree commit diff
path: root/synapse/appservice
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/appservice')
-rw-r--r--synapse/appservice/__init__.py90
-rw-r--r--synapse/appservice/api.py60
-rw-r--r--synapse/appservice/scheduler.py67
3 files changed, 143 insertions, 74 deletions
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index f7178ea0d3..bde9b51b2e 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 from synapse.api.constants import EventTypes
 
+from twisted.internet import defer
+
 import logging
 import re
 
@@ -79,13 +81,17 @@ class ApplicationService(object):
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
     def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None):
+                 sender=None, id=None, protocols=None):
         self.token = token
         self.url = url
         self.hs_token = hs_token
         self.sender = sender
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
+        if protocols:
+            self.protocols = set(protocols)
+        else:
+            self.protocols = set()
 
     def _check_namespaces(self, namespaces):
         # Sanity check that it is of the form:
@@ -138,65 +144,66 @@ class ApplicationService(object):
             return regex_obj["exclusive"]
         return False
 
-    def _matches_user(self, event, member_list):
-        if (hasattr(event, "sender") and
-                self.is_interested_in_user(event.sender)):
-            return True
+    @defer.inlineCallbacks
+    def _matches_user(self, event, store):
+        if not event:
+            defer.returnValue(False)
+
+        if self.is_interested_in_user(event.sender):
+            defer.returnValue(True)
         # also check m.room.member state key
-        if (hasattr(event, "type") and event.type == EventTypes.Member
-                and hasattr(event, "state_key")
-                and self.is_interested_in_user(event.state_key)):
-            return True
+        if (event.type == EventTypes.Member and
+                self.is_interested_in_user(event.state_key)):
+            defer.returnValue(True)
+
+        if not store:
+            defer.returnValue(False)
+
+        member_list = yield store.get_users_in_room(event.room_id)
+
         # check joined member events
         for user_id in member_list:
             if self.is_interested_in_user(user_id):
-                return True
-        return False
+                defer.returnValue(True)
+        defer.returnValue(False)
 
     def _matches_room_id(self, event):
         if hasattr(event, "room_id"):
             return self.is_interested_in_room(event.room_id)
         return False
 
-    def _matches_aliases(self, event, alias_list):
+    @defer.inlineCallbacks
+    def _matches_aliases(self, event, store):
+        if not store or not event:
+            defer.returnValue(False)
+
+        alias_list = yield store.get_aliases_for_room(event.room_id)
         for alias in alias_list:
             if self.is_interested_in_alias(alias):
-                return True
-        return False
+                defer.returnValue(True)
+        defer.returnValue(False)
 
-    def is_interested(self, event, restrict_to=None, aliases_for_event=None,
-                      member_list=None):
+    @defer.inlineCallbacks
+    def is_interested(self, event, store=None):
         """Check if this service is interested in this event.
 
         Args:
             event(Event): The event to check.
-            restrict_to(str): The namespace to restrict regex tests to.
-            aliases_for_event(list): A list of all the known room aliases for
-            this event.
-            member_list(list): A list of all joined user_ids in this room.
+            store(DataStore)
         Returns:
             bool: True if this service would like to know about this event.
         """
-        if aliases_for_event is None:
-            aliases_for_event = []
-        if member_list is None:
-            member_list = []
-
-        if restrict_to and restrict_to not in ApplicationService.NS_LIST:
-            # this is a programming error, so fail early and raise a general
-            # exception
-            raise Exception("Unexpected restrict_to value: %s". restrict_to)
-
-        if not restrict_to:
-            return (self._matches_user(event, member_list)
-                    or self._matches_aliases(event, aliases_for_event)
-                    or self._matches_room_id(event))
-        elif restrict_to == ApplicationService.NS_ALIASES:
-            return self._matches_aliases(event, aliases_for_event)
-        elif restrict_to == ApplicationService.NS_ROOMS:
-            return self._matches_room_id(event)
-        elif restrict_to == ApplicationService.NS_USERS:
-            return self._matches_user(event, member_list)
+        # Do cheap checks first
+        if self._matches_room_id(event):
+            defer.returnValue(True)
+
+        if (yield self._matches_aliases(event, store)):
+            defer.returnValue(True)
+
+        if (yield self._matches_user(event, store)):
+            defer.returnValue(True)
+
+        defer.returnValue(False)
 
     def is_interested_in_user(self, user_id):
         return (
@@ -216,6 +223,9 @@ class ApplicationService(object):
             or user_id == self.sender
         )
 
+    def is_interested_in_protocol(self, protocol):
+        return protocol in self.protocols
+
     def is_exclusive_alias(self, alias):
         return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
 
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 6da6a1b62e..066127b666 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
+from synapse.types import ThirdPartyEntityKind
 
 import logging
 import urllib
@@ -24,6 +25,28 @@ import urllib
 logger = logging.getLogger(__name__)
 
 
+def _is_valid_3pe_result(r, field):
+    if not isinstance(r, dict):
+        return False
+
+    for k in (field, "protocol"):
+        if k not in r:
+            return False
+        if not isinstance(r[k], str):
+            return False
+
+    if "fields" not in r:
+        return False
+    fields = r["fields"]
+    if not isinstance(fields, dict):
+        return False
+    for k in fields.keys():
+        if not isinstance(fields[k], str):
+            return False
+
+    return True
+
+
 class ApplicationServiceApi(SimpleHttpClient):
     """This class manages HS -> AS communications, including querying and
     pushing.
@@ -72,6 +95,43 @@ class ApplicationServiceApi(SimpleHttpClient):
         defer.returnValue(False)
 
     @defer.inlineCallbacks
+    def query_3pe(self, service, kind, protocol, fields):
+        if kind == ThirdPartyEntityKind.USER:
+            uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+            required_field = "userid"
+        elif kind == ThirdPartyEntityKind.LOCATION:
+            uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+            required_field = "alias"
+        else:
+            raise ValueError(
+                "Unrecognised 'kind' argument %r to query_3pe()", kind
+            )
+
+        try:
+            response = yield self.get_json(uri, fields)
+            if not isinstance(response, list):
+                logger.warning(
+                    "query_3pe to %s returned an invalid response %r",
+                    uri, response
+                )
+                defer.returnValue([])
+
+            ret = []
+            for r in response:
+                if _is_valid_3pe_result(r, field=required_field):
+                    ret.append(r)
+                else:
+                    logger.warning(
+                        "query_3pe to %s returned an invalid result %r",
+                        uri, r
+                    )
+
+            defer.returnValue(ret)
+        except Exception as ex:
+            logger.warning("query_3pe to %s threw exception %s", uri, ex)
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
         events = self._serialize(events)
 
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 9afc8fd754..68a9de17b8 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -48,9 +48,12 @@ UP & quit           +---------- YES                       SUCCESS
 This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
+from twisted.internet import defer
 
 from synapse.appservice import ApplicationServiceState
-from twisted.internet import defer
+from synapse.util.logcontext import preserve_fn
+from synapse.util.metrics import Measure
+
 import logging
 
 logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@ class ApplicationServiceScheduler(object):
         self.txn_ctrl = _TransactionController(
             self.clock, self.store, self.as_api, create_recoverer
         )
-        self.queuer = _ServiceQueuer(self.txn_ctrl)
+        self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
 
     @defer.inlineCallbacks
     def start(self):
@@ -94,38 +97,36 @@ class _ServiceQueuer(object):
     this schedules any other events in the queue to run.
     """
 
-    def __init__(self, txn_ctrl):
+    def __init__(self, txn_ctrl, clock):
         self.queued_events = {}  # dict of {service_id: [events]}
-        self.pending_requests = {}  # dict of {service_id: Deferred}
+        self.requests_in_flight = set()
         self.txn_ctrl = txn_ctrl
+        self.clock = clock
 
     def enqueue(self, service, event):
         # if this service isn't being sent something
-        if not self.pending_requests.get(service.id):
-            self._send_request(service, [event])
-        else:
-            # add to queue for this service
-            if service.id not in self.queued_events:
-                self.queued_events[service.id] = []
-            self.queued_events[service.id].append(event)
-
-    def _send_request(self, service, events):
-        # send request and add callbacks
-        d = self.txn_ctrl.send(service, events)
-        d.addBoth(self._on_request_finish)
-        d.addErrback(self._on_request_fail)
-        self.pending_requests[service.id] = d
-
-    def _on_request_finish(self, service):
-        self.pending_requests[service.id] = None
-        # if there are queued events, then send them.
-        if (service.id in self.queued_events
-                and len(self.queued_events[service.id]) > 0):
-            self._send_request(service, self.queued_events[service.id])
-            self.queued_events[service.id] = []
-
-    def _on_request_fail(self, err):
-        logger.error("AS request failed: %s", err)
+        self.queued_events.setdefault(service.id, []).append(event)
+        preserve_fn(self._send_request)(service)
+
+    @defer.inlineCallbacks
+    def _send_request(self, service):
+        if service.id in self.requests_in_flight:
+            return
+
+        self.requests_in_flight.add(service.id)
+        try:
+            while True:
+                events = self.queued_events.pop(service.id, [])
+                if not events:
+                    return
+
+                with Measure(self.clock, "servicequeuer.send"):
+                    try:
+                        yield self.txn_ctrl.send(service, events)
+                    except:
+                        logger.exception("AS request failed")
+        finally:
+            self.requests_in_flight.discard(service.id)
 
 
 class _TransactionController(object):
@@ -149,14 +150,12 @@ class _TransactionController(object):
             if service_is_up:
                 sent = yield txn.send(self.as_api)
                 if sent:
-                    txn.complete(self.store)
+                    yield txn.complete(self.store)
                 else:
-                    self._start_recoverer(service)
+                    preserve_fn(self._start_recoverer)(service)
         except Exception as e:
             logger.exception(e)
-            self._start_recoverer(service)
-        # request has finished
-        defer.returnValue(service)
+            preserve_fn(self._start_recoverer)(service)
 
     @defer.inlineCallbacks
     def on_recovered(self, recoverer):