diff --git a/changelog.d/46.feature b/changelog.d/46.feature
new file mode 100644
index 0000000000..7872d956e3
--- /dev/null
+++ b/changelog.d/46.feature
@@ -0,0 +1 @@
+Add a bulk version of the User Info API. Deprecate the single-use version.
\ No newline at end of file
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 060bf07197..3d47af0a8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from six.moves import urllib
@@ -1021,6 +1021,20 @@ class TransportLayerClient(object):
return self.client.get_json(destination=destination, path=path)
+ def get_info_of_users(self, destination: str, user_ids: List[str]):
+ """
+ Args:
+ destination: The remote server
+ user_ids: A list of user IDs to query info about
+
+ Returns:
+ Deferred[List]: A dictionary of User ID to information about that user.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info")
+ data = {"user_ids": user_ids}
+
+ return self.client.post_json(destination=destination, path=path, data=data)
+
def _create_path(federation_prefix, path, *args):
"""
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index af4595498c..cb6331d613 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -33,6 +33,7 @@ from synapse.api.urls import (
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_boolean_from_args,
parse_integer_from_args,
parse_json_object_from_request,
@@ -849,6 +850,57 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
+class FederationUserInfoServlet(BaseFederationServlet):
+ """
+ Return information about a set of users.
+
+ This API returns expiration and deactivation information about a set of
+ users. Requested users not local to this homeserver will be ignored.
+
+ Example request:
+ POST /users/info
+
+ {
+ "user_ids": [
+ "@alice:example.com",
+ "@bob:example.com"
+ ]
+ }
+
+ Example response
+ {
+ "@alice:example.com": {
+ "expired": false,
+ "deactivated": true
+ }
+ }
+ """
+
+ PATH = "/users/info"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ def __init__(self, handler, authenticator, ratelimiter, server_name):
+ super(FederationUserInfoServlet, self).__init__(
+ handler, authenticator, ratelimiter, server_name
+ )
+ self.handler = handler
+
+ async def on_POST(self, origin, content, query):
+ assert_params_in_dict(content, required=["user_ids"])
+
+ user_ids = content.get("user_ids", [])
+
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ data = await self.handler.store.get_info_for_users(user_ids)
+ return 200, data
+
+
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
@@ -1410,6 +1462,7 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
+ FederationUserInfoServlet,
) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index f9dfdce112..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,13 +14,16 @@
# limitations under the License.
import logging
+from typing import Dict
from signedjson.sign import sign_json
-from twisted.internet import defer
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
from synapse.types import UserID
from ._base import client_patterns
@@ -92,45 +95,43 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
-class UserInfoServlet(RestServlet):
+class SingleUserInfoServlet(RestServlet):
"""
+ Deprecated and replaced by `/users/info`
+
GET /user/{user_id}/info HTTP/1.1
"""
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
def __init__(self, hs):
- super(UserInfoServlet, self).__init__()
+ super(SingleUserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self.clock = hs.get_clock()
self.transport_layer = hs.get_federation_transport_client()
registry = hs.get_federation_registry()
if not registry.query_handlers.get("user_info"):
registry.register_query_handler("user_info", self._on_federation_query)
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
# Ensure the user is authenticated
- yield self.auth.get_user_by_req(request, allow_guest=False)
+ await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# Attempt to make a federation request to the server that owns this user
args = {"user_id": user_id}
- res = yield self.transport_layer.make_query(
+ res = await self.transport_layer.make_query(
user.domain, "user_info", args, retry_on_dns_fail=True
)
- defer.returnValue((200, res))
+ return 200, res
- res = yield self._get_user_info(user_id)
- defer.returnValue((200, res))
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
- @defer.inlineCallbacks
- def _on_federation_query(self, args):
+ async def _on_federation_query(self, args):
"""Called when a request for user information appears over federation
Args:
@@ -147,32 +148,72 @@ class UserInfoServlet(RestServlet):
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
- res = yield self._get_user_info(user_id)
- defer.returnValue(res)
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
- @defer.inlineCallbacks
- def _get_user_info(self, user_id):
- """Retrieve information about a given user
- Args:
- user_id (str): The User ID of a given user on this homeserver
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
- Returns:
- Deferred[dict]: Deactivation and expiration information for a given user
- """
- # Check whether user is deactivated
- is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+ GET /users/info HTTP/1.1
- # Check whether user is expired
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- is_expired = (
- expiration_ts is not None and self.clock.time_msec() >= expiration_ts
- )
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ for user_id, info in res:
+ user_id_to_info_dict[user_id] = info
- res = {"expired": is_expired, "deactivated": is_deactivated}
- defer.returnValue(res)
+ return 200, user_id_to_info_dict
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
UserInfoServlet(hs).register(http_server)
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index b07c44d87a..1f1a7b4e36 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,6 +17,7 @@
import logging
import re
+from typing import List
from six import iterkeys
@@ -304,6 +305,55 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ @defer.inlineCallbacks
+ def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+ Deferred[Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = yield self.db.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 572df8d80b..9f54b55acd 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -17,7 +17,7 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -460,3 +460,136 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.render(request)
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_matrix/client/unstable/admin/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.render(request)
+ self.assertEqual(request.code, 200)
|