diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index d4cad1b1ed..dd5e762e0d 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
+from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
- def query_3pu(self, service, protocol, fields):
- uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
- response = None
- try:
- response = yield self.get_json(uri, fields)
- defer.returnValue(response)
- except Exception as ex:
- logger.warning("query_3pu to %s threw exception %s", uri, ex)
- defer.returnValue([])
+ def query_3pe(self, service, kind, protocol, fields):
+ if kind == ThirdPartyEntityKind.USER:
+ uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
+ elif kind == ThirdPartyEntityKind.LOCATION:
+ uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
+ else:
+ raise ValueError(
+ "Unrecognised 'kind' argument %r to query_3pe()", kind
+ )
- @defer.inlineCallbacks
- def query_3pl(self, service, protocol, fields):
- uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
- response = None
try:
response = yield self.get_json(uri, fields)
defer.returnValue(response)
except Exception as ex:
- logger.warning("query_3pl to %s threw exception %s", uri, ex)
+ logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 03452f6bb0..52c127d2c1 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn
+from synapse.types import ThirdPartyEntityKind
import logging
@@ -169,37 +170,19 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
- def query_3pu(self, protocol, fields):
+ def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
- self.appservice_api.query_3pu(service, protocol, fields)
+ self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services
], consumeErrors=True)
- ret = []
- for (success, result) in results:
- if not success:
- continue
- if not isinstance(result, list):
- continue
- for r in result:
- if _is_valid_3pentity_result(r, field="userid"):
- ret.append(r)
- else:
- logger.warn("Application service returned an " +
- "invalid result %r", r)
-
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
- def query_3pl(self, protocol, fields):
- services = yield self._get_services_for_3pn(protocol)
-
- results = yield defer.DeferredList([
- self.appservice_api.query_3pl(service, protocol, fields)
- for service in services
- ], consumeErrors=True)
+ required_field = (
+ "userid" if kind == ThirdPartyEntityKind.USER else
+ "alias" if kind == ThirdPartyEntityKind.LOCATION else
+ None
+ )
ret = []
for (success, result) in results:
@@ -208,7 +191,7 @@ class ApplicationServicesHandler(object):
if not isinstance(result, list):
continue
for r in result:
- if _is_valid_3pentity_result(r, field="alias"):
+ if _is_valid_3pentity_result(r, field=required_field):
ret.append(r)
else:
logger.warn("Application service returned an " +
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index d229e4b818..9abca3a8ad 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -19,6 +19,7 @@ import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
+from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
fields = request.args
del fields["access_token"]
- results = yield self.appservice_handler.query_3pu(protocol, fields)
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.USER, protocol, fields
+ )
defer.returnValue((200, results))
@@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
fields = request.args
del fields["access_token"]
- results = yield self.appservice_handler.query_3pl(protocol, fields)
+ results = yield self.appservice_handler.query_3pe(
+ ThirdPartyEntityKind.LOCATION, protocol, fields
+ )
defer.returnValue((200, results))
diff --git a/synapse/types.py b/synapse/types.py
index 5349b0c450..fd17ecbbe0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
+
+
+# Some arbitrary constants used for internal API enumerations. Don't rely on
+# exact values; always pass or compare symbolically
+class ThirdPartyEntityKind(object):
+ USER = 'user'
+ LOCATION = 'location'
|