summary refs log tree commit diff
path: root/synapse/rest/client/v2_alpha/groups.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-12-30 08:39:59 -0500
committerGitHub <noreply@github.com>2020-12-30 08:39:59 -0500
commitb7c580e33341ffbd12533a572439778929f811bf (patch)
tree41a68e5c7b8f4e4992c7aa0fd16e8d5be0e456a5 /synapse/rest/client/v2_alpha/groups.py
parentAdd additional type hints to the storage module. (#8980) (diff)
downloadsynapse-b7c580e33341ffbd12533a572439778929f811bf.tar.xz
Check if group IDs are valid before using them. (#8977)
Diffstat (limited to 'synapse/rest/client/v2_alpha/groups.py')
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
1 files changed, 45 insertions, 3 deletions
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a3bb095c2d..5b5da71815 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from functools import wraps
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
 logger = logging.getLogger(__name__)
 
 
+def _validate_group_id(f):
+    """Wrapper to validate the form of the group ID.
+
+    Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+    """
+
+    @wraps(f)
+    def wrapper(self, request, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+        return f(self, request, group_id, *args, **kwargs)
+
+    return wrapper
+
+
 class GroupServlet(RestServlet):
     """Get the group profile
     """
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
 
         return 200, group_description
 
+    @_validate_group_id
     async def on_POST(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
         result = await self.groups_handler.get_rooms_in_group(
             group_id, requester_user_id
         )
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
         return 200, result
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id, config_key):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()