diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index cc4af23962..b0eb0c6d9d 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
+def _is_valid_3pe_metadata(info):
+ if "instances" not in info:
+ return False
+ if not isinstance(info["instances"], list):
+ return False
+ return True
+
+
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
@@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.quote(protocol)
)
try:
- defer.returnValue((yield self.get_json(uri, {})))
+ info = yield self.get_json(uri, {})
+
+ if not _is_valid_3pe_metadata(info):
+ logger.warning("query_3pe_protocol to %s did not return a"
+ " valid result", uri)
+ defer.returnValue(None)
+
+ defer.returnValue(info)
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
- defer.returnValue({})
+ defer.returnValue(None)
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b440280b74..88fa0bb2e4 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_3pe_protocols(self):
+ def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services()
protocols = {}
+
+ # Collect up all the individual protocol responses out of the ASes
for s in services:
for p in s.protocols:
- protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+ if only_protocol is not None and p != only_protocol:
+ continue
+
+ if p not in protocols:
+ protocols[p] = []
+
+ info = yield self.appservice_api.get_3pe_protocol(s, p)
+
+ if info is not None:
+ protocols[p].append(info)
+
+ def _merge_instances(infos):
+ if not infos:
+ return {}
+
+ # Merge the 'instances' lists of multiple results, but just take
+ # the other fields from the first as they ought to be identical
+ # copy the result so as not to corrupt the cached one
+ combined = dict(infos[0])
+ combined["instances"] = list(combined["instances"])
+
+ for info in infos[1:]:
+ combined["instances"].extend(info["instances"])
+
+ return combined
+
+ for p in protocols.keys():
+ protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 4f6f1a7e17..dca615927a 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
defer.returnValue((200, protocols))
+class ThirdPartyProtocolServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
+ releases=())
+
+ def __init__(self, hs):
+ super(ThirdPartyProtocolServlet, self).__init__()
+
+ self.auth = hs.get_auth()
+ self.appservice_handler = hs.get_application_service_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, protocol):
+ yield self.auth.get_user_by_req(request)
+
+ protocols = yield self.appservice_handler.get_3pe_protocols(
+ only_protocol=protocol,
+ )
+ if protocol in protocols:
+ defer.returnValue((200, protocols[protocol]))
+ else:
+ defer.returnValue((404, {"error": "Unknown protocol"}))
+
+
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=())
@@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
+ ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)
|