summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2020-06-19 16:17:13 +0100
committerGitHub <noreply@github.com>2020-06-19 16:17:13 +0100
commit53949a908fc785ab6c27b40343040df2462cd309 (patch)
tree23feb3417a2bc47d0d9c1ab3be13c383ea07a261
parentPerformance improvements to marking expired users as inactive (#47) (diff)
downloadsynapse-53949a908fc785ab6c27b40343040df2462cd309.tar.xz
Add a bulk user info endpoint and deprecate the old one (#46)
The current `/user/<user_id>/info` API was useful in that it could be used by any user to lookup whether another user was deactivate or expired. However, it was impractical as it only allowed for a single lookup at once. Clients trying to use this API were met with speed issues as they tried to query this information for all users in a room.

This PR adds an equivalent CS and Federation API that takes a list of user IDs, and returning a mapping from user ID to info dictionary.

Note that the federation in this PR was a bit trickier than in the original #12 as we can no longer use a federation query, as those don't allow for JSON bodies - which we require to pass a list of user IDs. Instead we do the whole thing of adding a method to transport/client and transport/server.

This PR also adds unittests. The earlier PR used Sytest, presumably for testing across federation, but as this is Synapse-specific that felt a little gross. Unit tests for the deprecated endpoint have not been added.
-rw-r--r--changelog.d/46.feature1
-rw-r--r--synapse/federation/transport/client.py16
-rw-r--r--synapse/federation/transport/server.py53
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py113
-rw-r--r--synapse/storage/data_stores/main/registration.py50
-rw-r--r--tests/handlers/test_user_directory.py135
6 files changed, 330 insertions, 38 deletions
diff --git a/changelog.d/46.feature b/changelog.d/46.feature
new file mode 100644

index 0000000000..7872d956e3 --- /dev/null +++ b/changelog.d/46.feature
@@ -0,0 +1 @@ +Add a bulk version of the User Info API. Deprecate the single-use version. \ No newline at end of file diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 060bf07197..3d47af0a8b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from six.moves import urllib @@ -1021,6 +1021,20 @@ class TransportLayerClient(object): return self.client.get_json(destination=destination, path=path) + def get_info_of_users(self, destination: str, user_ids: List[str]): + """ + Args: + destination: The remote server + user_ids: A list of user IDs to query info about + + Returns: + Deferred[List]: A dictionary of User ID to information about that user. + """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info") + data = {"user_ids": user_ids} + + return self.client.post_json(destination=destination, path=path, data=data) + def _create_path(federation_prefix, path, *args): """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..cb6331d613 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -33,6 +33,7 @@ from synapse.api.urls import ( from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( + assert_params_in_dict, parse_boolean_from_args, parse_integer_from_args, parse_json_object_from_request, @@ -849,6 +850,57 @@ class PublicRoomList(BaseFederationServlet): return 200, data +class FederationUserInfoServlet(BaseFederationServlet): + """ + Return information about a set of users. + + This API returns expiration and deactivation information about a set of + users. Requested users not local to this homeserver will be ignored. + + Example request: + POST /users/info + + { + "user_ids": [ + "@alice:example.com", + "@bob:example.com" + ] + } + + Example response + { + "@alice:example.com": { + "expired": false, + "deactivated": true + } + } + """ + + PATH = "/users/info" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + def __init__(self, handler, authenticator, ratelimiter, server_name): + super(FederationUserInfoServlet, self).__init__( + handler, authenticator, ratelimiter, server_name + ) + self.handler = handler + + async def on_POST(self, origin, content, query): + assert_params_in_dict(content, required=["user_ids"]) + + user_ids = content.get("user_ids", []) + + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + data = await self.handler.store.get_info_for_users(user_ids) + return 200, data + + class FederationVersionServlet(BaseFederationServlet): PATH = "/version" @@ -1410,6 +1462,7 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationUserInfoServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index f9dfdce112..6e8300d6a5 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,13 +14,16 @@ # limitations under the License. import logging +from typing import Dict from signedjson.sign import sign_json -from twisted.internet import defer - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) from synapse.types import UserID from ._base import client_patterns @@ -92,45 +95,43 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results -class UserInfoServlet(RestServlet): +class SingleUserInfoServlet(RestServlet): """ + Deprecated and replaced by `/users/info` + GET /user/{user_id}/info HTTP/1.1 """ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$") def __init__(self, hs): - super(UserInfoServlet, self).__init__() + super(SingleUserInfoServlet, self).__init__() self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self.clock = hs.get_clock() self.transport_layer = hs.get_federation_transport_client() registry = hs.get_federation_registry() if not registry.query_handlers.get("user_info"): registry.register_query_handler("user_info", self._on_federation_query) - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): # Ensure the user is authenticated - yield self.auth.get_user_by_req(request, allow_guest=False) + await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): # Attempt to make a federation request to the server that owns this user args = {"user_id": user_id} - res = yield self.transport_layer.make_query( + res = await self.transport_layer.make_query( user.domain, "user_info", args, retry_on_dns_fail=True ) - defer.returnValue((200, res)) + return 200, res - res = yield self._get_user_info(user_id) - defer.returnValue((200, res)) + user_id_to_info = await self.store.get_info_for_users([user_id]) + return 200, user_id_to_info[user_id] - @defer.inlineCallbacks - def _on_federation_query(self, args): + async def _on_federation_query(self, args): """Called when a request for user information appears over federation Args: @@ -147,32 +148,72 @@ class UserInfoServlet(RestServlet): if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") - res = yield self._get_user_info(user_id) - defer.returnValue(res) + user_ids_to_info_dict = await self.store.get_info_for_users([user_id]) + return user_ids_to_info_dict[user_id] - @defer.inlineCallbacks - def _get_user_info(self, user_id): - """Retrieve information about a given user - Args: - user_id (str): The User ID of a given user on this homeserver +class UserInfoServlet(RestServlet): + """Bulk version of `/user/{user_id}/info` endpoint - Returns: - Deferred[dict]: Deactivation and expiration information for a given user - """ - # Check whether user is deactivated - is_deactivated = yield self.store.get_user_deactivated_status(user_id) + GET /users/info HTTP/1.1 - # Check whether user is expired - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - is_expired = ( - expiration_ts is not None and self.clock.time_msec() >= expiration_ts - ) + Returns a dictionary of user_id to info dictionary. Supports remote users + """ + + PATTERNS = client_patterns("/users/info$", unstable=True, releases=()) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + + async def on_POST(self, request): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + # Extract the user_ids from the request + body = parse_json_object_from_request(request) + assert_params_in_dict(body, required=["user_ids"]) + + user_ids = body["user_ids"] + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + # Separate local and remote users + local_user_ids = set() + remote_server_to_user_ids = {} # type: Dict[str, set] + for user_id in user_ids: + user = UserID.from_string(user_id) + + if self.hs.is_mine(user): + local_user_ids.add(user_id) + else: + remote_server_to_user_ids.setdefault(user.domain, set()) + remote_server_to_user_ids[user.domain].add(user_id) + + # Retrieve info of all local users + user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids) + + # Request info of each remote user from their remote homeserver + for server_name, user_id_set in remote_server_to_user_ids.items(): + # Make a request to the given server about their own users + res = await self.transport_layer.get_info_of_users( + server_name, list(user_id_set) + ) + + for user_id, info in res: + user_id_to_info_dict[user_id] = info - res = {"expired": is_expired, "deactivated": is_deactivated} - defer.returnValue(res) + return 200, user_id_to_info_dict def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + SingleUserInfoServlet(hs).register(http_server) UserInfoServlet(hs).register(http_server) diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index b07c44d87a..1f1a7b4e36 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py
@@ -17,6 +17,7 @@ import logging import re +from typing import List from six import iterkeys @@ -304,6 +305,55 @@ class RegistrationWorkerStore(SQLBaseStore): desc="delete_account_validity_for_user", ) + @defer.inlineCallbacks + def get_info_for_users( + self, user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = yield self.db.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self.clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user): """Determines if a user is an admin of this homeserver. diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 572df8d80b..9f54b55acd 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -17,7 +17,7 @@ from mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client.v2_alpha import account, account_validity, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest @@ -460,3 +460,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["results"]) == 0) + + +class UserInfoTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + account_validity.register_servlets, + synapse.rest.client.v2_alpha.user_directory.register_servlets, + account.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # Set accounts to expire after a week + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + return config + + def prepare(self, reactor, clock, hs): + super(UserInfoTestCase, self).prepare(reactor, clock, hs) + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def test_user_info(self): + """Test /users/info for local users from the Client-Server API""" + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request info about each user from user_three + request, channel = self.make_request( + "POST", + path="/_matrix/client/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + access_token=user_three_token, + shorthand=False, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def test_user_info_federation(self): + """Test that /users/info can be called from the Federation API, and + and that we can query remote users from the Client-Server API + """ + user_one, user_two, user_three, user_three_token = self.setup_test_users() + + # Request information about our local users from the perspective of a remote server + request, channel = self.make_request( + "POST", + path="/_matrix/federation/unstable/users/info", + content={"user_ids": [user_one, user_two, user_three]}, + ) + self.render(request) + self.assertEquals(200, channel.code) + + # Check the state of user_one matches + user_one_info = channel.json_body[user_one] + self.assertTrue(user_one_info["deactivated"]) + self.assertFalse(user_one_info["expired"]) + + # Check the state of user_two matches + user_two_info = channel.json_body[user_two] + self.assertFalse(user_two_info["deactivated"]) + self.assertTrue(user_two_info["expired"]) + + # Check the state of user_three matches + user_three_info = channel.json_body[user_three] + self.assertFalse(user_three_info["deactivated"]) + self.assertFalse(user_three_info["expired"]) + + def setup_test_users(self): + """Create an admin user and three test users, each with a different state""" + + # Create an admin user to expire other users with + self.register_user("admin", "adminpassword", admin=True) + admin_token = self.login("admin", "adminpassword") + + # Create three users + user_one = self.register_user("alice", "pass") + user_one_token = self.login("alice", "pass") + user_two = self.register_user("bob", "pass") + user_three = self.register_user("carl", "pass") + user_three_token = self.login("carl", "pass") + + # Deactivate user_one + self.deactivate(user_one, user_one_token) + + # Expire user_two + self.expire(user_two, admin_token) + + # Do nothing to user_three + + return user_one, user_two, user_three, user_three_token + + def expire(self, user_id_to_expire, admin_tok): + url = "/_matrix/client/unstable/admin/account_validity/validity" + request_data = { + "user_id": user_id_to_expire, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request, channel = self.make_request( + "POST", url, request_data, access_token=admin_tok + ) + self.render(request) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def deactivate(self, user_id, tok): + request_data = { + "auth": {"type": "m.login.password", "user": user_id, "password": "pass"}, + "erase": False, + } + request, channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.render(request) + self.assertEqual(request.code, 200)