diff options
author | Paul "LeoNerd" Evans <paul@matrix.org> | 2016-08-18 17:33:56 +0100 |
---|---|---|
committer | Paul "LeoNerd" Evans <paul@matrix.org> | 2016-08-18 17:33:56 +0100 |
commit | 65201631a407b71087bb52da8b591e0975c463ec (patch) | |
tree | e7077e567fddbc9a0ac30204e19dc1f0bc9e74e2 | |
parent | More warnings about invalid results from AS 3PE query (diff) | |
download | synapse-65201631a407b71087bb52da8b591e0975c463ec.tar.xz |
Move validation logic for AS 3PE query response into ApplicationServiceApi class, to keep the handler logic neater
Diffstat (limited to '')
-rw-r--r-- | synapse/appservice/api.py | 43 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 46 |
2 files changed, 44 insertions, 45 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index dd5e762e0d..066127b666 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -25,6 +25,28 @@ import urllib logger = logging.getLogger(__name__) +def _is_valid_3pe_result(r, field): + if not isinstance(r, dict): + return False + + for k in (field, "protocol"): + if k not in r: + return False + if not isinstance(r[k], str): + return False + + if "fields" not in r: + return False + fields = r["fields"] + if not isinstance(fields, dict): + return False + for k in fields.keys(): + if not isinstance(fields[k], str): + return False + + return True + + class ApplicationServiceApi(SimpleHttpClient): """This class manages HS -> AS communications, including querying and pushing. @@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient): def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + required_field = "alias" else: raise ValueError( "Unrecognised 'kind' argument %r to query_3pe()", kind @@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient): try: response = yield self.get_json(uri, fields) - defer.returnValue(response) + if not isinstance(response, list): + logger.warning( + "query_3pe to %s returned an invalid response %r", + uri, response + ) + defer.returnValue([]) + + ret = [] + for r in response: + if _is_valid_3pe_result(r, field=required_field): + ret.append(r) + else: + logger.warning( + "query_3pe to %s returned an invalid result %r", + uri, r + ) + + defer.returnValue(ret) except Exception as ex: logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 6f162a3c00..18dca462a9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,7 +18,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn -from synapse.types import ThirdPartyEntityKind import logging @@ -36,28 +35,6 @@ def log_failure(failure): ) -def _is_valid_3pentity_result(r, field): - if not isinstance(r, dict): - return False - - for k in (field, "protocol"): - if k not in r: - return False - if not isinstance(r[k], str): - return False - - if "fields" not in r: - return False - fields = r["fields"] - if not isinstance(fields, dict): - return False - for k in fields.keys(): - if not isinstance(fields[k], str): - return False - - return True - - class ApplicationServicesHandler(object): def __init__(self, hs): @@ -178,29 +155,10 @@ class ApplicationServicesHandler(object): for service in services ], consumeErrors=True) - required_field = ( - "userid" if kind == ThirdPartyEntityKind.USER else - "alias" if kind == ThirdPartyEntityKind.LOCATION else - None - ) - ret = [] for (success, result) in results: - if not success: - logger.warn("Application service failed %r", result) - continue - if not isinstance(result, list): - logger.warn( - "Application service returned an invalid response %r", result - ) - continue - for r in result: - if _is_valid_3pentity_result(r, field=required_field): - ret.append(r) - else: - logger.warn( - "Application service returned an invalid result %r", r - ) + if success: + ret.extend(result) defer.returnValue(ret) |