summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 17:19:55 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 17:19:55 +0100
commitb515f844ee07c7d6aa1d7e56faa8b65d282e9341 (patch)
tree7d154e0c825e687fe19eed42998182aa07847d1b /synapse
parentMinor syntax neatenings (diff)
downloadsynapse-b515f844ee07c7d6aa1d7e56faa8b65d282e9341.tar.xz
Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration
Diffstat (limited to 'synapse')
-rw-r--r--synapse/appservice/api.py25
-rw-r--r--synapse/handlers/appservice.py35
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py9
-rw-r--r--synapse/types.py7
4 files changed, 34 insertions, 42 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index d4cad1b1ed..dd5e762e0d 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
+from synapse.types import ThirdPartyEntityKind
 
 import logging
 import urllib
@@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
         defer.returnValue(False)
 
     @defer.inlineCallbacks
-    def query_3pu(self, service, protocol, fields):
-        uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
-        response = None
-        try:
-            response = yield self.get_json(uri, fields)
-            defer.returnValue(response)
-        except Exception as ex:
-            logger.warning("query_3pu to %s threw exception %s", uri, ex)
-            defer.returnValue([])
+    def query_3pe(self, service, kind, protocol, fields):
+        if kind == ThirdPartyEntityKind.USER:
+            uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+        elif kind == ThirdPartyEntityKind.LOCATION:
+            uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+        else:
+            raise ValueError(
+                "Unrecognised 'kind' argument %r to query_3pe()", kind
+            )
 
-    @defer.inlineCallbacks
-    def query_3pl(self, service, protocol, fields):
-        uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
-        response = None
         try:
             response = yield self.get_json(uri, fields)
             defer.returnValue(response)
         except Exception as ex:
-            logger.warning("query_3pl to %s threw exception %s", uri, ex)
+            logger.warning("query_3pe to %s threw exception %s", uri, ex)
             defer.returnValue([])
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 03452f6bb0..52c127d2c1 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import preserve_fn
+from synapse.types import ThirdPartyEntityKind
 
 import logging
 
@@ -169,37 +170,19 @@ class ApplicationServicesHandler(object):
                 defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def query_3pu(self, protocol, fields):
+    def query_3pe(self, kind, protocol, fields):
         services = yield self._get_services_for_3pn(protocol)
 
         results = yield defer.DeferredList([
-            self.appservice_api.query_3pu(service, protocol, fields)
+            self.appservice_api.query_3pe(service, kind, protocol, fields)
             for service in services
         ], consumeErrors=True)
 
-        ret = []
-        for (success, result) in results:
-            if not success:
-                continue
-            if not isinstance(result, list):
-                continue
-            for r in result:
-                if _is_valid_3pentity_result(r, field="userid"):
-                    ret.append(r)
-                else:
-                    logger.warn("Application service returned an " +
-                                "invalid result %r", r)
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
-    def query_3pl(self, protocol, fields):
-        services = yield self._get_services_for_3pn(protocol)
-
-        results = yield defer.DeferredList([
-            self.appservice_api.query_3pl(service, protocol, fields)
-            for service in services
-        ], consumeErrors=True)
+        required_field = (
+            "userid" if kind == ThirdPartyEntityKind.USER else
+            "alias" if kind == ThirdPartyEntityKind.LOCATION else
+            None
+        )
 
         ret = []
         for (success, result) in results:
@@ -208,7 +191,7 @@ class ApplicationServicesHandler(object):
             if not isinstance(result, list):
                 continue
             for r in result:
-                if _is_valid_3pentity_result(r, field="alias"):
+                if _is_valid_3pentity_result(r, field=required_field):
                     ret.append(r)
                 else:
                     logger.warn("Application service returned an " +
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index d229e4b818..9abca3a8ad 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -19,6 +19,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.http.servlet import RestServlet
+from synapse.types import ThirdPartyEntityKind
 from ._base import client_v2_patterns
 
 logger = logging.getLogger(__name__)
@@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
         fields = request.args
         del fields["access_token"]
 
-        results = yield self.appservice_handler.query_3pu(protocol, fields)
+        results = yield self.appservice_handler.query_3pe(
+            ThirdPartyEntityKind.USER, protocol, fields
+        )
 
         defer.returnValue((200, results))
 
@@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
         fields = request.args
         del fields["access_token"]
 
-        results = yield self.appservice_handler.query_3pl(protocol, fields)
+        results = yield self.appservice_handler.query_3pe(
+            ThirdPartyEntityKind.LOCATION, protocol, fields
+        )
 
         defer.returnValue((200, results))
 
diff --git a/synapse/types.py b/synapse/types.py
index 5349b0c450..fd17ecbbe0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
             return "t%d-%d" % (self.topological, self.stream)
         else:
             return "s%d" % (self.stream,)
+
+
+# Some arbitrary constants used for internal API enumerations. Don't rely on
+# exact values; always pass or compare symbolically
+class ThirdPartyEntityKind(object):
+    USER = 'user'
+    LOCATION = 'location'