summary refs log tree commit diff
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 16:09:50 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 16:09:50 +0100
commit06964c4a0adabf7d983cbd0d2c6d83eba6fcaf79 (patch)
treecd3eb6fb5c8c5558aba941ddc55a0cfff1dd0703
parentRemove TODO note about request fields being strings - they're always strings (diff)
downloadsynapse-06964c4a0adabf7d983cbd0d2c6d83eba6fcaf79.tar.xz
Copypasta the 3PU support code to also do 3PL
-rw-r--r--synapse/appservice/api.py11
-rw-r--r--synapse/handlers/appservice.py33
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py20
3 files changed, 61 insertions, 3 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e05570cc8b..4ccb5c43c1 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -83,6 +83,17 @@ class ApplicationServiceApi(SimpleHttpClient):
             defer.returnValue([])
 
     @defer.inlineCallbacks
+    def query_3pl(self, service, protocol, fields):
+        uri = service.url + ("/3pl/%s" % urllib.quote(protocol))
+        response = None
+        try:
+            response = yield self.get_json(uri, fields)
+            defer.returnValue(response)
+        except Exception as ex:
+            logger.warning("query_3pl to %s threw exception %s", uri, ex)
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
         events = self._serialize(events)
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5ed694e711..72c36615df 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -34,11 +34,11 @@ def log_failure(failure):
         )
     )
 
-def _is_valid_3pu_result(r):
+def _is_valid_3pentity_result(r, field):
     if not isinstance(r, dict):
         return False
 
-    for k in ("userid", "protocol"):
+    for k in (field, "protocol"):
         if k not in r:
             return False
         if not isinstance(r[k], str):
@@ -185,7 +185,34 @@ class ApplicationServicesHandler(object):
             if not isinstance(result, list):
                 continue
             for r in result:
-                if _is_valid_3pu_result(r):
+                if _is_valid_3pentity_result(r, field="userid"):
+                    ret.append(r)
+                else:
+                    logger.warn("Application service returned an " +
+                                "invalid result %r", r)
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def query_3pl(self, protocol, fields):
+        services = yield self._get_services_for_3pn(protocol)
+
+        deferreds = []
+        for service in services:
+            deferreds.append(self.appservice_api.query_3pl(
+                service, protocol, fields
+            ))
+
+        results = yield defer.DeferredList(deferreds, consumeErrors=True)
+
+        ret = []
+        for (success, result) in results:
+            if not success:
+                continue
+            if not isinstance(result, list):
+                continue
+            for r in result:
+                if _is_valid_3pentity_result(r, field="alias"):
                     ret.append(r)
                 else:
                     logger.warn("Application service returned an " +
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index bce104c545..eec08425e6 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -43,5 +43,25 @@ class ThirdPartyUserServlet(RestServlet):
         defer.returnValue((200, results))
 
 
+class ThirdPartyLocationServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
+                                  releases=())
+
+    def __init__(self, hs):
+        super(ThirdPartyLocationServlet, self).__init__()
+
+        self.appservice_handler = hs.get_application_service_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, protocol):
+        fields = request.args
+        del fields["access_token"]
+
+        results = yield self.appservice_handler.query_3pl(protocol, fields)
+
+        defer.returnValue((200, results))
+
+
 def register_servlets(hs, http_server):
     ThirdPartyUserServlet(hs).register(http_server)
+    ThirdPartyLocationServlet(hs).register(http_server)