diff options
author | Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> | 2020-06-19 16:17:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-19 16:17:13 +0100 |
commit | 53949a908fc785ab6c27b40343040df2462cd309 (patch) | |
tree | 23feb3417a2bc47d0d9c1ab3be13c383ea07a261 /synapse/rest | |
parent | Performance improvements to marking expired users as inactive (#47) (diff) | |
download | synapse-53949a908fc785ab6c27b40343040df2462cd309.tar.xz |
Add a bulk user info endpoint and deprecate the old one (#46)
The current `/user/<user_id>/info` API was useful in that it could be used by any user to lookup whether another user was deactivate or expired. However, it was impractical as it only allowed for a single lookup at once. Clients trying to use this API were met with speed issues as they tried to query this information for all users in a room. This PR adds an equivalent CS and Federation API that takes a list of user IDs, and returning a mapping from user ID to info dictionary. Note that the federation in this PR was a bit trickier than in the original #12 as we can no longer use a federation query, as those don't allow for JSON bodies - which we require to pass a list of user IDs. Instead we do the whole thing of adding a method to transport/client and transport/server. This PR also adds unittests. The earlier PR used Sytest, presumably for testing across federation, but as this is Synapse-specific that felt a little gross. Unit tests for the deprecated endpoint have not been added.
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/client/v2_alpha/user_directory.py | 113 |
1 files changed, 77 insertions, 36 deletions
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index f9dfdce112..6e8300d6a5 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -14,13 +14,16 @@ # limitations under the License. import logging +from typing import Dict from signedjson.sign import sign_json -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) from synapse.types import UserID from ._base import client_patterns @@ -92,45 +95,43 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results -class UserInfoServlet(RestServlet): +class SingleUserInfoServlet(RestServlet): """ + Deprecated and replaced by `/users/info` + GET /user/{user_id}/info HTTP/1.1 """ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$") def __init__(self, hs): - super(UserInfoServlet, self).__init__() + super(SingleUserInfoServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self.clock = hs.get_clock() self.transport_layer = hs.get_federation_transport_client() registry = hs.get_federation_registry() if not registry.query_handlers.get("user_info"): registry.register_query_handler("user_info", self._on_federation_query) - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): # Ensure the user is authenticated - yield self.auth.get_user_by_req(request, allow_guest=False) + await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): # Attempt to make a federation request to the server that owns this user args = {"user_id": user_id} - res = yield self.transport_layer.make_query( + res = await self.transport_layer.make_query( user.domain, "user_info", args, retry_on_dns_fail=True ) - defer.returnValue((200, res)) + return 200, res - res = yield self._get_user_info(user_id) - defer.returnValue((200, res)) + user_id_to_info = await self.store.get_info_for_users([user_id]) + return 200, user_id_to_info[user_id] - @defer.inlineCallbacks - def _on_federation_query(self, args): + async def _on_federation_query(self, args): """Called when a request for user information appears over federation Args: @@ -147,32 +148,72 @@ class UserInfoServlet(RestServlet): if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") - res = yield self._get_user_info(user_id) - defer.returnValue(res) + user_ids_to_info_dict = await self.store.get_info_for_users([user_id]) + return user_ids_to_info_dict[user_id] - @defer.inlineCallbacks - def _get_user_info(self, user_id): - """Retrieve information about a given user - Args: - user_id (str): The User ID of a given user on this homeserver +class UserInfoServlet(RestServlet): + """Bulk version of `/user/{user_id}/info` endpoint - Returns: - Deferred[dict]: Deactivation and expiration information for a given user - """ - # Check whether user is deactivated - is_deactivated = yield self.store.get_user_deactivated_status(user_id) + GET /users/info HTTP/1.1 - # Check whether user is expired - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - is_expired = ( - expiration_ts is not None and self.clock.time_msec() >= expiration_ts - ) + Returns a dictionary of user_id to info dictionary. Supports remote users + """ + + PATTERNS = client_patterns("/users/info$", unstable=True, releases=()) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + + async def on_POST(self, request): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + # Extract the user_ids from the request + body = parse_json_object_from_request(request) + assert_params_in_dict(body, required=["user_ids"]) + + user_ids = body["user_ids"] + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + # Separate local and remote users + local_user_ids = set() + remote_server_to_user_ids = {} # type: Dict[str, set] + for user_id in user_ids: + user = UserID.from_string(user_id) + + if self.hs.is_mine(user): + local_user_ids.add(user_id) + else: + remote_server_to_user_ids.setdefault(user.domain, set()) + remote_server_to_user_ids[user.domain].add(user_id) + + # Retrieve info of all local users + user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids) + + # Request info of each remote user from their remote homeserver + for server_name, user_id_set in remote_server_to_user_ids.items(): + # Make a request to the given server about their own users + res = await self.transport_layer.get_info_of_users( + server_name, list(user_id_set) + ) + + for user_id, info in res: + user_id_to_info_dict[user_id] = info - res = {"expired": is_expired, "deactivated": is_deactivated} - defer.returnValue(res) + return 200, user_id_to_info_dict def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + SingleUserInfoServlet(hs).register(http_server) UserInfoServlet(hs).register(http_server) |