summary refs log tree commit diff
path: root/synapse/rest/admin
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/admin')
-rw-r--r--synapse/rest/admin/rooms.py17
-rw-r--r--synapse/rest/admin/users.py26
2 files changed, 28 insertions, 15 deletions
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f0cddd2d2c..40ee33646c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
 
         # Get the room ID from the identifier.
         try:
-            remote_room_hosts = [
+            remote_room_hosts: Optional[List[str]] = [
                 x.decode("ascii") for x in request.args[b"server_name"]
-            ]  # type: Optional[List[str]]
+            ]
         except Exception:
             remote_room_hosts = None
         room_id, remote_room_hosts = await self.resolve_room_id(
@@ -462,6 +462,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id
@@ -500,7 +501,13 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             admin_user_id = None
 
             for admin_user in reversed(admin_users):
-                if room_state.get((EventTypes.Member, admin_user)):
+                (
+                    current_membership_type,
+                    _,
+                ) = await self.store.get_local_current_membership_for_user_in_room(
+                    admin_user, room_id
+                )
+                if current_membership_type == "join":
                     admin_user_id = admin_user
                     break
 
@@ -652,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, "filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(
-                json_decoder.decode(filter_json)
-            )  # type: Optional[Filter]
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
         else:
             event_filter = None
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7d75564758..589e47fa47 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}  # type: Dict[str, int]
+        self.nonces: Dict[str, int] = {}
         self.hs = hs
 
     def _clear_old_nonces(self):
@@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        body = parse_json_object_from_request(request)
+        if self.account_activity_handler.on_legacy_admin_request_callback:
+            expiration_ts = await (
+                self.account_activity_handler.on_legacy_admin_request_callback(request)
+            )
+        else:
+            body = parse_json_object_from_request(request)
 
-        if "user_id" not in body:
-            raise SynapseError(400, "Missing property 'user_id' in the request body")
+            if "user_id" not in body:
+                raise SynapseError(
+                    400,
+                    "Missing property 'user_id' in the request body",
+                )
 
-        expiration_ts = await self.account_activity_handler.renew_account_for_user(
-            body["user_id"],
-            body.get("expiration_ts"),
-            not body.get("enable_renewal_emails", True),
-        )
+            expiration_ts = await self.account_activity_handler.renew_account_for_user(
+                body["user_id"],
+                body.get("expiration_ts"),
+                not body.get("enable_renewal_emails", True),
+            )
 
         res = {"expiration_ts": expiration_ts}
         return 200, res