summary refs log tree commit diff
diff options
context:
space:
mode:
authorPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 17:33:56 +0100
committerPaul "LeoNerd" Evans <paul@matrix.org>2016-08-18 17:33:56 +0100
commit65201631a407b71087bb52da8b591e0975c463ec (patch)
treee7077e567fddbc9a0ac30204e19dc1f0bc9e74e2
parentMore warnings about invalid results from AS 3PE query (diff)
downloadsynapse-65201631a407b71087bb52da8b591e0975c463ec.tar.xz
Move validation logic for AS 3PE query response into ApplicationServiceApi class, to keep the handler logic neater
-rw-r--r--synapse/appservice/api.py43
-rw-r--r--synapse/handlers/appservice.py46
2 files changed, 44 insertions, 45 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index dd5e762e0d..066127b666 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -25,6 +25,28 @@ import urllib
 logger = logging.getLogger(__name__)
 
 
+def _is_valid_3pe_result(r, field):
+    if not isinstance(r, dict):
+        return False
+
+    for k in (field, "protocol"):
+        if k not in r:
+            return False
+        if not isinstance(r[k], str):
+            return False
+
+    if "fields" not in r:
+        return False
+    fields = r["fields"]
+    if not isinstance(fields, dict):
+        return False
+    for k in fields.keys():
+        if not isinstance(fields[k], str):
+            return False
+
+    return True
+
+
 class ApplicationServiceApi(SimpleHttpClient):
     """This class manages HS -> AS communications, including querying and
     pushing.
@@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient):
     def query_3pe(self, service, kind, protocol, fields):
         if kind == ThirdPartyEntityKind.USER:
             uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+            required_field = "userid"
         elif kind == ThirdPartyEntityKind.LOCATION:
             uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+            required_field = "alias"
         else:
             raise ValueError(
                 "Unrecognised 'kind' argument %r to query_3pe()", kind
@@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         try:
             response = yield self.get_json(uri, fields)
-            defer.returnValue(response)
+            if not isinstance(response, list):
+                logger.warning(
+                    "query_3pe to %s returned an invalid response %r",
+                    uri, response
+                )
+                defer.returnValue([])
+
+            ret = []
+            for r in response:
+                if _is_valid_3pe_result(r, field=required_field):
+                    ret.append(r)
+                else:
+                    logger.warning(
+                        "query_3pe to %s returned an invalid result %r",
+                        uri, r
+                    )
+
+            defer.returnValue(ret)
         except Exception as ex:
             logger.warning("query_3pe to %s threw exception %s", uri, ex)
             defer.returnValue([])
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6f162a3c00..18dca462a9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.util.metrics import Measure
 from synapse.util.logcontext import preserve_fn
-from synapse.types import ThirdPartyEntityKind
 
 import logging
 
@@ -36,28 +35,6 @@ def log_failure(failure):
     )
 
 
-def _is_valid_3pentity_result(r, field):
-    if not isinstance(r, dict):
-        return False
-
-    for k in (field, "protocol"):
-        if k not in r:
-            return False
-        if not isinstance(r[k], str):
-            return False
-
-    if "fields" not in r:
-        return False
-    fields = r["fields"]
-    if not isinstance(fields, dict):
-        return False
-    for k in fields.keys():
-        if not isinstance(fields[k], str):
-            return False
-
-    return True
-
-
 class ApplicationServicesHandler(object):
 
     def __init__(self, hs):
@@ -178,29 +155,10 @@ class ApplicationServicesHandler(object):
             for service in services
         ], consumeErrors=True)
 
-        required_field = (
-            "userid" if kind == ThirdPartyEntityKind.USER else
-            "alias" if kind == ThirdPartyEntityKind.LOCATION else
-            None
-        )
-
         ret = []
         for (success, result) in results:
-            if not success:
-                logger.warn("Application service failed %r", result)
-                continue
-            if not isinstance(result, list):
-                logger.warn(
-                    "Application service returned an invalid response %r", result
-                )
-                continue
-            for r in result:
-                if _is_valid_3pentity_result(r, field=required_field):
-                    ret.append(r)
-                else:
-                    logger.warn(
-                        "Application service returned an invalid result %r", r
-                    )
+            if success:
+                ret.extend(result)
 
         defer.returnValue(ret)