summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
authorAndrew Morgan <1342360+anoadragon453@users.noreply.github.com>2020-06-19 16:17:13 +0100
committerGitHub <noreply@github.com>2020-06-19 16:17:13 +0100
commit53949a908fc785ab6c27b40343040df2462cd309 (patch)
tree23feb3417a2bc47d0d9c1ab3be13c383ea07a261 /synapse/rest
parentPerformance improvements to marking expired users as inactive (#47) (diff)
downloadsynapse-53949a908fc785ab6c27b40343040df2462cd309.tar.xz
Add a bulk user info endpoint and deprecate the old one (#46)
The current `/user/<user_id>/info` API was useful in that it could be used by any user to lookup whether another user was deactivate or expired. However, it was impractical as it only allowed for a single lookup at once. Clients trying to use this API were met with speed issues as they tried to query this information for all users in a room.

This PR adds an equivalent CS and Federation API that takes a list of user IDs, and returning a mapping from user ID to info dictionary.

Note that the federation in this PR was a bit trickier than in the original #12 as we can no longer use a federation query, as those don't allow for JSON bodies - which we require to pass a list of user IDs. Instead we do the whole thing of adding a method to transport/client and transport/server.

This PR also adds unittests. The earlier PR used Sytest, presumably for testing across federation, but as this is Synapse-specific that felt a little gross. Unit tests for the deprecated endpoint have not been added.
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py113
1 files changed, 77 insertions, 36 deletions
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index f9dfdce112..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,13 +14,16 @@
 # limitations under the License.
 
 import logging
+from typing import Dict
 
 from signedjson.sign import sign_json
 
-from twisted.internet import defer
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+)
 from synapse.types import UserID
 
 from ._base import client_patterns
@@ -92,45 +95,43 @@ class UserDirectorySearchRestServlet(RestServlet):
         return 200, results
 
 
-class UserInfoServlet(RestServlet):
+class SingleUserInfoServlet(RestServlet):
     """
+    Deprecated and replaced by `/users/info`
+
     GET /user/{user_id}/info HTTP/1.1
     """
 
     PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
 
     def __init__(self, hs):
-        super(UserInfoServlet, self).__init__()
+        super(SingleUserInfoServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self.clock = hs.get_clock()
         self.transport_layer = hs.get_federation_transport_client()
         registry = hs.get_federation_registry()
 
         if not registry.query_handlers.get("user_info"):
             registry.register_query_handler("user_info", self._on_federation_query)
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, user_id):
+    async def on_GET(self, request, user_id):
         # Ensure the user is authenticated
-        yield self.auth.get_user_by_req(request, allow_guest=False)
+        await self.auth.get_user_by_req(request)
 
         user = UserID.from_string(user_id)
         if not self.hs.is_mine(user):
             # Attempt to make a federation request to the server that owns this user
             args = {"user_id": user_id}
-            res = yield self.transport_layer.make_query(
+            res = await self.transport_layer.make_query(
                 user.domain, "user_info", args, retry_on_dns_fail=True
             )
-            defer.returnValue((200, res))
+            return 200, res
 
-        res = yield self._get_user_info(user_id)
-        defer.returnValue((200, res))
+        user_id_to_info = await self.store.get_info_for_users([user_id])
+        return 200, user_id_to_info[user_id]
 
-    @defer.inlineCallbacks
-    def _on_federation_query(self, args):
+    async def _on_federation_query(self, args):
         """Called when a request for user information appears over federation
 
         Args:
@@ -147,32 +148,72 @@ class UserInfoServlet(RestServlet):
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        res = yield self._get_user_info(user_id)
-        defer.returnValue(res)
+        user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+        return user_ids_to_info_dict[user_id]
 
-    @defer.inlineCallbacks
-    def _get_user_info(self, user_id):
-        """Retrieve information about a given user
 
-        Args:
-            user_id (str): The User ID of a given user on this homeserver
+class UserInfoServlet(RestServlet):
+    """Bulk version of `/user/{user_id}/info` endpoint
 
-        Returns:
-            Deferred[dict]: Deactivation and expiration information for a given user
-        """
-        # Check whether user is deactivated
-        is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+    GET /users/info HTTP/1.1
 
-        # Check whether user is expired
-        expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
-        is_expired = (
-            expiration_ts is not None and self.clock.time_msec() >= expiration_ts
-        )
+    Returns a dictionary of user_id to info dictionary. Supports remote users
+    """
+
+    PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+    def __init__(self, hs):
+        super(UserInfoServlet, self).__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.transport_layer = hs.get_federation_transport_client()
+
+    async def on_POST(self, request):
+        # Ensure the user is authenticated
+        await self.auth.get_user_by_req(request)
+
+        # Extract the user_ids from the request
+        body = parse_json_object_from_request(request)
+        assert_params_in_dict(body, required=["user_ids"])
+
+        user_ids = body["user_ids"]
+        if not isinstance(user_ids, list):
+            raise SynapseError(
+                400,
+                "'user_ids' must be a list of user ID strings",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        # Separate local and remote users
+        local_user_ids = set()
+        remote_server_to_user_ids = {}  # type: Dict[str, set]
+        for user_id in user_ids:
+            user = UserID.from_string(user_id)
+
+            if self.hs.is_mine(user):
+                local_user_ids.add(user_id)
+            else:
+                remote_server_to_user_ids.setdefault(user.domain, set())
+                remote_server_to_user_ids[user.domain].add(user_id)
+
+        # Retrieve info of all local users
+        user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+        # Request info of each remote user from their remote homeserver
+        for server_name, user_id_set in remote_server_to_user_ids.items():
+            # Make a request to the given server about their own users
+            res = await self.transport_layer.get_info_of_users(
+                server_name, list(user_id_set)
+            )
+
+            for user_id, info in res:
+                user_id_to_info_dict[user_id] = info
 
-        res = {"expired": is_expired, "deactivated": is_deactivated}
-        defer.returnValue(res)
+        return 200, user_id_to_info_dict
 
 
 def register_servlets(hs, http_server):
     UserDirectorySearchRestServlet(hs).register(http_server)
+    SingleUserInfoServlet(hs).register(http_server)
     UserInfoServlet(hs).register(http_server)