diff options
author | Paul "LeoNerd" Evans <paul@matrix.org> | 2016-08-18 17:19:55 +0100 |
---|---|---|
committer | Paul "LeoNerd" Evans <paul@matrix.org> | 2016-08-18 17:19:55 +0100 |
commit | b515f844ee07c7d6aa1d7e56faa8b65d282e9341 (patch) | |
tree | 7d154e0c825e687fe19eed42998182aa07847d1b /synapse/handlers/appservice.py | |
parent | Minor syntax neatenings (diff) | |
download | synapse-b515f844ee07c7d6aa1d7e56faa8b65d282e9341.tar.xz |
Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/appservice.py | 35 |
1 files changed, 9 insertions, 26 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 03452f6bb0..52c127d2c1 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn +from synapse.types import ThirdPartyEntityKind import logging @@ -169,37 +170,19 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def query_3pu(self, protocol, fields): + def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) results = yield defer.DeferredList([ - self.appservice_api.query_3pu(service, protocol, fields) + self.appservice_api.query_3pe(service, kind, protocol, fields) for service in services ], consumeErrors=True) - ret = [] - for (success, result) in results: - if not success: - continue - if not isinstance(result, list): - continue - for r in result: - if _is_valid_3pentity_result(r, field="userid"): - ret.append(r) - else: - logger.warn("Application service returned an " + - "invalid result %r", r) - - defer.returnValue(ret) - - @defer.inlineCallbacks - def query_3pl(self, protocol, fields): - services = yield self._get_services_for_3pn(protocol) - - results = yield defer.DeferredList([ - self.appservice_api.query_3pl(service, protocol, fields) - for service in services - ], consumeErrors=True) + required_field = ( + "userid" if kind == ThirdPartyEntityKind.USER else + "alias" if kind == ThirdPartyEntityKind.LOCATION else + None + ) ret = [] for (success, result) in results: @@ -208,7 +191,7 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pentity_result(r, field="alias"): + if _is_valid_3pentity_result(r, field=required_field): ret.append(r) else: logger.warn("Application service returned an " + |