summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/rest/admin/rooms.py48
-rw-r--r--synapse/storage/databases/main/devices.py32
2 files changed, 64 insertions, 16 deletions
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..b902af8028 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -25,13 +25,17 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -45,12 +49,14 @@ class ShutdownRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -86,13 +92,15 @@ class DeleteRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -146,12 +154,12 @@ class ListRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -236,19 +244,24 @@ class RoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room_with_stats(room_id)
         if not ret:
             raise NotFoundError("Room not found")
 
-        return 200, ret
+        members = await self.store.get_users_in_room(room_id)
+        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+        return (200, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -258,12 +271,14 @@ class RoomMembersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
@@ -280,14 +295,16 @@ class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
 
-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -314,7 +331,6 @@ class JoinRoomAliasServlet(RestServlet):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfb4f87b8f..9097677648 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
                 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
             )
 
+    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+        """Retrieve number of all devices of given users.
+        Only returns number of devices that are not marked as hidden.
+
+        Args:
+            user_ids: The IDs of the users which owns devices
+        Returns:
+            Number of devices of this users.
+        """
+
+        def count_devices_by_users_txn(txn, user_ids):
+            sql = """
+                SELECT count(*)
+                FROM devices
+                WHERE
+                    hidden = '0' AND
+            """
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "user_id", user_ids
+            )
+
+            txn.execute(sql + clause, args)
+            return txn.fetchone()[0]
+
+        if not user_ids:
+            return 0
+
+        return await self.db_pool.runInteraction(
+            "count_devices_by_users", count_devices_by_users_txn, user_ids
+        )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.