diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/appservice/api.py | 19 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 33 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/thirdparty.py | 24 |
3 files changed, 72 insertions, 4 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index cc4af23962..b0eb0c6d9d 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000 APP_SERVICE_PREFIX = "/_matrix/app/unstable" +def _is_valid_3pe_metadata(info): + if "instances" not in info: + return False + if not isinstance(info["instances"], list): + return False + return True + + def _is_valid_3pe_result(r, field): if not isinstance(r, dict): return False @@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.quote(protocol) ) try: - defer.returnValue((yield self.get_json(uri, {}))) + info = yield self.get_json(uri, {}) + + if not _is_valid_3pe_metadata(info): + logger.warning("query_3pe_protocol to %s did not return a" + " valid result", uri) + defer.returnValue(None) + + defer.returnValue(info) except Exception as ex: logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex) - defer.returnValue({}) + defer.returnValue(None) key = (service.id, protocol) return self.protocol_meta_cache.get(key) or ( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index b440280b74..88fa0bb2e4 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -176,12 +176,41 @@ class ApplicationServicesHandler(object): defer.returnValue(ret) @defer.inlineCallbacks - def get_3pe_protocols(self): + def get_3pe_protocols(self, only_protocol=None): services = yield self.store.get_app_services() protocols = {} + + # Collect up all the individual protocol responses out of the ASes for s in services: for p in s.protocols: - protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) + if only_protocol is not None and p != only_protocol: + continue + + if p not in protocols: + protocols[p] = [] + + info = yield self.appservice_api.get_3pe_protocol(s, p) + + if info is not None: + protocols[p].append(info) + + def _merge_instances(infos): + if not infos: + return {} + + # Merge the 'instances' lists of multiple results, but just take + # the other fields from the first as they ought to be identical + # copy the result so as not to corrupt the cached one + combined = dict(infos[0]) + combined["instances"] = list(combined["instances"]) + + for info in infos[1:]: + combined["instances"].extend(info["instances"]) + + return combined + + for p in protocols.keys(): + protocols[p] = _merge_instances(protocols[p]) defer.returnValue(protocols) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 4f6f1a7e17..dca615927a 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet): defer.returnValue((200, protocols)) +class ThirdPartyProtocolServlet(RestServlet): + PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$", + releases=()) + + def __init__(self, hs): + super(ThirdPartyProtocolServlet, self).__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + @defer.inlineCallbacks + def on_GET(self, request, protocol): + yield self.auth.get_user_by_req(request) + + protocols = yield self.appservice_handler.get_3pe_protocols( + only_protocol=protocol, + ) + if protocol in protocols: + defer.returnValue((200, protocols[protocol])) + else: + defer.returnValue((404, {"error": "Unknown protocol"})) + + class ThirdPartyUserServlet(RestServlet): PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", releases=()) @@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet): def register_servlets(hs, http_server): ThirdPartyProtocolsServlet(hs).register(http_server) + ThirdPartyProtocolServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server) |