diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index dd5e762e0d..066127b666 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -25,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__)
+def _is_valid_3pe_result(r, field):
+ if not isinstance(r, dict):
+ return False
+
+ for k in (field, "protocol"):
+ if k not in r:
+ return False
+ if not isinstance(r[k], str):
+ return False
+
+ if "fields" not in r:
+ return False
+ fields = r["fields"]
+ if not isinstance(fields, dict):
+ return False
+ for k in fields.keys():
+ if not isinstance(fields[k], str):
+ return False
+
+ return True
+
+
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+ required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+ required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
@@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient):
try:
response = yield self.get_json(uri, fields)
- defer.returnValue(response)
+ if not isinstance(response, list):
+ logger.warning(
+ "query_3pe to %s returned an invalid response %r",
+ uri, response
+ )
+ defer.returnValue([])
+
+ ret = []
+ for r in response:
+ if _is_valid_3pe_result(r, field=required_field):
+ ret.append(r)
+ else:
+ logger.warning(
+ "query_3pe to %s returned an invalid result %r",
+ uri, r
+ )
+
+ defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6f162a3c00..18dca462a9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn
-from synapse.types import ThirdPartyEntityKind
import logging
@@ -36,28 +35,6 @@ def log_failure(failure):
)
-def _is_valid_3pentity_result(r, field):
- if not isinstance(r, dict):
- return False
-
- for k in (field, "protocol"):
- if k not in r:
- return False
- if not isinstance(r[k], str):
- return False
-
- if "fields" not in r:
- return False
- fields = r["fields"]
- if not isinstance(fields, dict):
- return False
- for k in fields.keys():
- if not isinstance(fields[k], str):
- return False
-
- return True
-
-
class ApplicationServicesHandler(object):
def __init__(self, hs):
@@ -178,29 +155,10 @@ class ApplicationServicesHandler(object):
for service in services
], consumeErrors=True)
- required_field = (
- "userid" if kind == ThirdPartyEntityKind.USER else
- "alias" if kind == ThirdPartyEntityKind.LOCATION else
- None
- )
-
ret = []
for (success, result) in results:
- if not success:
- logger.warn("Application service failed %r", result)
- continue
- if not isinstance(result, list):
- logger.warn(
- "Application service returned an invalid response %r", result
- )
- continue
- for r in result:
- if _is_valid_3pentity_result(r, field=required_field):
- ret.append(r)
- else:
- logger.warn(
- "Application service returned an invalid result %r", r
- )
+ if success:
+ ret.extend(result)
defer.returnValue(ret)
|