summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/events/spamcheck.py5
-rw-r--r--synapse/handlers/directory.py59
-rw-r--r--synapse/handlers/identity.py9
-rw-r--r--synapse/handlers/message.py24
-rw-r--r--synapse/handlers/presence.py159
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/ui_auth/checkers.py35
-rw-r--r--synapse/rest/admin/rooms.py134
-rw-r--r--synapse/storage/database.py4
-rw-r--r--synapse/util/caches/lrucache.py77
10 files changed, 321 insertions, 187 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 7118d5f52d..d5fa195094 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple,
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
+from synapse.types import RoomAlias
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -113,7 +114,9 @@ class SpamChecker:
 
         return True
 
-    async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+    async def user_may_create_room_alias(
+        self, userid: str, room_alias: RoomAlias
+    ) -> bool:
         """Checks if a given user may create a room alias
 
         If this method returns false, the association request will be rejected.
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 90932316f3..de1b14cde3 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
 
 import logging
 import string
-from typing import Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
 from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
 from synapse.api.errors import (
@@ -27,15 +27,19 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.appservice import ApplicationService
-from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class DirectoryHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.state = hs.get_state_handler()
@@ -60,7 +64,7 @@ class DirectoryHandler(BaseHandler):
         room_id: str,
         servers: Optional[Iterable[str]] = None,
         creator: Optional[str] = None,
-    ):
+    ) -> None:
         # general association creation for both human users and app services
 
         for wchar in string.whitespace:
@@ -104,8 +108,9 @@ class DirectoryHandler(BaseHandler):
         """
 
         user_id = requester.user.to_string()
+        room_alias_str = room_alias.to_string()
 
-        if len(room_alias.to_string()) > MAX_ALIAS_LENGTH:
+        if len(room_alias_str) > MAX_ALIAS_LENGTH:
             raise SynapseError(
                 400,
                 "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
@@ -114,7 +119,7 @@ class DirectoryHandler(BaseHandler):
 
         service = requester.app_service
         if service:
-            if not service.is_interested_in_alias(room_alias.to_string()):
+            if not service.is_interested_in_alias(room_alias_str):
                 raise SynapseError(
                     400,
                     "This application service has not reserved this kind of alias.",
@@ -138,7 +143,7 @@ class DirectoryHandler(BaseHandler):
                 raise AuthError(403, "This user is not permitted to create this alias")
 
             if not self.config.is_alias_creation_allowed(
-                user_id, room_id, room_alias.to_string()
+                user_id, room_id, room_alias_str
             ):
                 # Lets just return a generic message, as there may be all sorts of
                 # reasons why we said no. TODO: Allow configurable error messages
@@ -211,7 +216,7 @@ class DirectoryHandler(BaseHandler):
 
     async def delete_appservice_association(
         self, service: ApplicationService, room_alias: RoomAlias
-    ):
+    ) -> None:
         if not service.is_interested_in_alias(room_alias.to_string()):
             raise SynapseError(
                 400,
@@ -220,7 +225,7 @@ class DirectoryHandler(BaseHandler):
             )
         await self._delete_association(room_alias)
 
-    async def _delete_association(self, room_alias: RoomAlias):
+    async def _delete_association(self, room_alias: RoomAlias) -> str:
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room alias must be local")
 
@@ -228,17 +233,19 @@ class DirectoryHandler(BaseHandler):
 
         return room_id
 
-    async def get_association(self, room_alias: RoomAlias):
+    async def get_association(self, room_alias: RoomAlias) -> JsonDict:
         room_id = None
         if self.hs.is_mine(room_alias):
-            result = await self.get_association_from_room_alias(room_alias)
+            result = await self.get_association_from_room_alias(
+                room_alias
+            )  # type: Optional[RoomAliasMapping]
 
             if result:
                 room_id = result.room_id
                 servers = result.servers
         else:
             try:
-                result = await self.federation.make_query(
+                fed_result = await self.federation.make_query(
                     destination=room_alias.domain,
                     query_type="directory",
                     args={"room_alias": room_alias.to_string()},
@@ -248,13 +255,13 @@ class DirectoryHandler(BaseHandler):
             except CodeMessageException as e:
                 logging.warning("Error retrieving alias")
                 if e.code == 404:
-                    result = None
+                    fed_result = None
                 else:
                     raise
 
-            if result and "room_id" in result and "servers" in result:
-                room_id = result["room_id"]
-                servers = result["servers"]
+            if fed_result and "room_id" in fed_result and "servers" in fed_result:
+                room_id = fed_result["room_id"]
+                servers = fed_result["servers"]
 
         if not room_id:
             raise SynapseError(
@@ -275,7 +282,7 @@ class DirectoryHandler(BaseHandler):
 
         return {"room_id": room_id, "servers": servers}
 
-    async def on_directory_query(self, args):
+    async def on_directory_query(self, args: JsonDict) -> JsonDict:
         room_alias = RoomAlias.from_string(args["room_alias"])
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room Alias is not hosted on this homeserver")
@@ -293,7 +300,7 @@ class DirectoryHandler(BaseHandler):
 
     async def _update_canonical_alias(
         self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
-    ):
+    ) -> None:
         """
         Send an updated canonical alias event if the removed alias was set as
         the canonical alias or listed in the alt_aliases field.
@@ -344,7 +351,9 @@ class DirectoryHandler(BaseHandler):
                 ratelimit=False,
             )
 
-    async def get_association_from_room_alias(self, room_alias: RoomAlias):
+    async def get_association_from_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         result = await self.store.get_association_from_room_alias(room_alias)
         if not result:
             # Query AS to see if it exists
@@ -372,7 +381,7 @@ class DirectoryHandler(BaseHandler):
         # either no interested services, or no service with an exclusive lock
         return True
 
-    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
         """Determine whether a user can delete an alias.
 
         One of the following must be true:
@@ -394,14 +403,13 @@ class DirectoryHandler(BaseHandler):
         if not room_id:
             return False
 
-        res = await self.auth.check_can_change_room_list(
+        return await self.auth.check_can_change_room_list(
             room_id, UserID.from_string(user_id)
         )
-        return res
 
     async def edit_published_room_list(
         self, requester: Requester, room_id: str, visibility: str
-    ):
+    ) -> None:
         """Edit the entry of the room in the published room list.
 
         requester
@@ -469,7 +477,7 @@ class DirectoryHandler(BaseHandler):
 
     async def edit_published_appservice_room_list(
         self, appservice_id: str, network_id: str, room_id: str, visibility: str
-    ):
+    ) -> None:
         """Add or remove a room from the appservice/network specific public
         room list.
 
@@ -499,5 +507,4 @@ class DirectoryHandler(BaseHandler):
                 room_id, requester.user.to_string()
             )
 
-        aliases = await self.store.get_aliases_for_room(room_id)
-        return aliases
+        return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0b3b1fadb5..33d16fbf9c 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -17,7 +17,7 @@
 """Utilities for interacting with Identity Servers"""
 import logging
 import urllib.parse
-from typing import Awaitable, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
 
 from synapse.api.errors import (
     CodeMessageException,
@@ -41,13 +41,16 @@ from synapse.util.stringutils import (
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 id_server_scheme = "https://"
 
 
 class IdentityHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # An HTTP client for contacting trusted URLs.
@@ -80,7 +83,7 @@ class IdentityHandler(BaseHandler):
         request: SynapseRequest,
         medium: str,
         address: str,
-    ):
+    ) -> None:
         """Used to ratelimit requests to `/requestToken` by IP and address.
 
         Args:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ec8eb21674..49f8aa25ea 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -66,7 +66,7 @@ logger = logging.getLogger(__name__)
 class MessageHandler:
     """Contains some read only APIs to get state about a room"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
@@ -91,7 +91,7 @@ class MessageHandler:
         room_id: str,
         event_type: str,
         state_key: str,
-    ) -> dict:
+    ) -> Optional[EventBase]:
         """Get data from a room.
 
         Args:
@@ -115,6 +115,10 @@ class MessageHandler:
             data = await self.state.get_current_state(room_id, event_type, state_key)
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
+            # If the membership is not JOIN, then the event ID should exist.
+            assert (
+                membership_event_id is not None
+            ), "check_user_in_room_or_world_readable returned invalid data"
             room_state = await self.state_store.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
@@ -186,10 +190,12 @@ class MessageHandler:
 
             event = last_events[0]
             if visible_events:
-                room_state = await self.state_store.get_state_for_events(
+                room_state_events = await self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
-                room_state = room_state[event.event_id]
+                room_state = room_state_events[
+                    event.event_id
+                ]  # type: Mapping[Any, EventBase]
             else:
                 raise AuthError(
                     403,
@@ -210,10 +216,14 @@ class MessageHandler:
                 )
                 room_state = await self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
-                room_state = await self.state_store.get_state_for_events(
+                # If the membership is not JOIN, then the event ID should exist.
+                assert (
+                    membership_event_id is not None
+                ), "check_user_in_room_or_world_readable returned invalid data"
+                room_state_events = await self.state_store.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
-                room_state = room_state[membership_event_id]
+                room_state = room_state_events[membership_event_id]
 
         now = self.clock.time_msec()
         events = await self._event_serializer.serialize_events(
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 12df35f26e..ebbc234334 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -28,6 +28,7 @@ from bisect import bisect
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
+    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -232,23 +233,23 @@ class BasePresenceHandler(abc.ABC):
         """
 
     async def update_external_syncs_row(
-        self, process_id, user_id, is_syncing, sync_time_msec
-    ):
+        self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+    ) -> None:
         """Update the syncing users for an external process as a delta.
 
         This is a no-op when presence is handled by a different worker.
 
         Args:
-            process_id (str): An identifier for the process the users are
+            process_id: An identifier for the process the users are
                 syncing against. This allows synapse to process updates
                 as user start and stop syncing against a given process.
-            user_id (str): The user who has started or stopped syncing
-            is_syncing (bool): Whether or not the user is now syncing
-            sync_time_msec(int): Time in ms when the user was last syncing
+            user_id: The user who has started or stopped syncing
+            is_syncing: Whether or not the user is now syncing
+            sync_time_msec: Time in ms when the user was last syncing
         """
         pass
 
-    async def update_external_syncs_clear(self, process_id):
+    async def update_external_syncs_clear(self, process_id: str) -> None:
         """Marks all users that had been marked as syncing by a given process
         as offline.
 
@@ -304,7 +305,7 @@ class _NullContextManager(ContextManager[None]):
 
 
 class WorkerPresenceHandler(BasePresenceHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
 
@@ -327,7 +328,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         # user_id -> last_sync_ms. Lists the users that have stopped syncing but
         # we haven't notified the presence writer of that yet
-        self.users_going_offline = {}
+        self.users_going_offline = {}  # type: Dict[str, int]
 
         self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs)
         self._set_state_client = ReplicationPresenceSetState.make_client(hs)
@@ -346,24 +347,21 @@ class WorkerPresenceHandler(BasePresenceHandler):
             self._on_shutdown,
         )
 
-    def _on_shutdown(self):
+    def _on_shutdown(self) -> None:
         if self._presence_enabled:
             self.hs.get_tcp_replication().send_command(
                 ClearUserSyncsCommand(self.instance_id)
             )
 
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+    def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None:
         if self._presence_enabled:
             self.hs.get_tcp_replication().send_user_sync(
                 self.instance_id, user_id, is_syncing, last_sync_ms
             )
 
-    def mark_as_coming_online(self, user_id):
+    def mark_as_coming_online(self, user_id: str) -> None:
         """A user has started syncing. Send a UserSync to the presence writer,
         unless they had recently stopped syncing.
-
-        Args:
-            user_id (str)
         """
         going_offline = self.users_going_offline.pop(user_id, None)
         if not going_offline:
@@ -371,18 +369,15 @@ class WorkerPresenceHandler(BasePresenceHandler):
             # were offline
             self.send_user_sync(user_id, True, self.clock.time_msec())
 
-    def mark_as_going_offline(self, user_id):
+    def mark_as_going_offline(self, user_id: str) -> None:
         """A user has stopped syncing. We wait before notifying the presence
         writer as its likely they'll come back soon. This allows us to avoid
         sending a stopped syncing immediately followed by a started syncing
         notification to the presence writer
-
-        Args:
-            user_id (str)
         """
         self.users_going_offline[user_id] = self.clock.time_msec()
 
-    def send_stop_syncing(self):
+    def send_stop_syncing(self) -> None:
         """Check if there are any users who have stopped syncing a while ago and
         haven't come back yet. If there are poke the presence writer about them.
         """
@@ -430,7 +425,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
 
         return _user_syncing()
 
-    async def notify_from_replication(self, states, stream_id):
+    async def notify_from_replication(
+        self, states: List[UserPresenceState], stream_id: int
+    ) -> None:
         parties = await get_interested_parties(self.store, self.presence_router, states)
         room_ids_to_states, users_to_states = parties
 
@@ -478,7 +475,12 @@ class WorkerPresenceHandler(BasePresenceHandler):
             if count > 0
         ]
 
-    async def set_state(self, target_user, state, ignore_status_msg=False):
+    async def set_state(
+        self,
+        target_user: UserID,
+        state: JsonDict,
+        ignore_status_msg: bool = False,
+    ) -> None:
         """Set the presence state of the user."""
         presence = state["presence"]
 
@@ -508,7 +510,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             ignore_status_msg=ignore_status_msg,
         )
 
-    async def bump_presence_active_time(self, user):
+    async def bump_presence_active_time(self, user: UserID) -> None:
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -592,8 +594,8 @@ class PresenceHandler(BasePresenceHandler):
         # we assume that all the sync requests on that process have stopped.
         # Stored as a dict from process_id to set of user_id, and a dict of
         # process_id to millisecond timestamp last updated.
-        self.external_process_to_current_syncs = {}  # type: Dict[int, Set[str]]
-        self.external_process_last_updated_ms = {}  # type: Dict[int, int]
+        self.external_process_to_current_syncs = {}  # type: Dict[str, Set[str]]
+        self.external_process_last_updated_ms = {}  # type: Dict[str, int]
 
         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
 
@@ -633,7 +635,7 @@ class PresenceHandler(BasePresenceHandler):
         self._event_pos = self.store.get_current_events_token()
         self._event_processing = False
 
-    async def _on_shutdown(self):
+    async def _on_shutdown(self) -> None:
         """Gets called when shutting down. This lets us persist any updates that
         we haven't yet persisted, e.g. updates that only changes some internal
         timers. This allows changes to persist across startup without having to
@@ -662,7 +664,7 @@ class PresenceHandler(BasePresenceHandler):
             )
         logger.info("Finished _on_shutdown")
 
-    async def _persist_unpersisted_changes(self):
+    async def _persist_unpersisted_changes(self) -> None:
         """We periodically persist the unpersisted changes, as otherwise they
         may stack up and slow down shutdown times.
         """
@@ -762,7 +764,7 @@ class PresenceHandler(BasePresenceHandler):
                         states, destinations
                     )
 
-    async def _handle_timeouts(self):
+    async def _handle_timeouts(self) -> None:
         """Checks the presence of users that have timed out and updates as
         appropriate.
         """
@@ -814,7 +816,7 @@ class PresenceHandler(BasePresenceHandler):
 
         return await self._update_states(changes)
 
-    async def bump_presence_active_time(self, user):
+    async def bump_presence_active_time(self, user: UserID) -> None:
         """We've seen the user do something that indicates they're interacting
         with the app.
         """
@@ -911,17 +913,17 @@ class PresenceHandler(BasePresenceHandler):
         return []
 
     async def update_external_syncs_row(
-        self, process_id, user_id, is_syncing, sync_time_msec
-    ):
+        self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
+    ) -> None:
         """Update the syncing users for an external process as a delta.
 
         Args:
-            process_id (str): An identifier for the process the users are
+            process_id: An identifier for the process the users are
                 syncing against. This allows synapse to process updates
                 as user start and stop syncing against a given process.
-            user_id (str): The user who has started or stopped syncing
-            is_syncing (bool): Whether or not the user is now syncing
-            sync_time_msec(int): Time in ms when the user was last syncing
+            user_id: The user who has started or stopped syncing
+            is_syncing: Whether or not the user is now syncing
+            sync_time_msec: Time in ms when the user was last syncing
         """
         with (await self.external_sync_linearizer.queue(process_id)):
             prev_state = await self.current_state_for_user(user_id)
@@ -958,7 +960,7 @@ class PresenceHandler(BasePresenceHandler):
 
             self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
 
-    async def update_external_syncs_clear(self, process_id):
+    async def update_external_syncs_clear(self, process_id: str) -> None:
         """Marks all users that had been marked as syncing by a given process
         as offline.
 
@@ -979,12 +981,12 @@ class PresenceHandler(BasePresenceHandler):
             )
             self.external_process_last_updated_ms.pop(process_id, None)
 
-    async def current_state_for_user(self, user_id):
+    async def current_state_for_user(self, user_id: str) -> UserPresenceState:
         """Get the current presence state for a user."""
         res = await self.current_state_for_users([user_id])
         return res[user_id]
 
-    async def _persist_and_notify(self, states):
+    async def _persist_and_notify(self, states: List[UserPresenceState]) -> None:
         """Persist states in the database, poke the notifier and send to
         interested remote servers
         """
@@ -1005,7 +1007,7 @@ class PresenceHandler(BasePresenceHandler):
         # stream (which is updated by `store.update_presence`).
         await self.maybe_send_presence_to_interested_destinations(states)
 
-    async def incoming_presence(self, origin, content):
+    async def incoming_presence(self, origin: str, content: JsonDict) -> None:
         """Called when we receive a `m.presence` EDU from a remote server."""
         if not self._presence_enabled:
             return
@@ -1055,7 +1057,9 @@ class PresenceHandler(BasePresenceHandler):
             federation_presence_counter.inc(len(updates))
             await self._update_states(updates)
 
-    async def set_state(self, target_user, state, ignore_status_msg=False):
+    async def set_state(
+        self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+    ) -> None:
         """Set the presence state of the user."""
         status_msg = state.get("status_msg", None)
         presence = state["presence"]
@@ -1089,7 +1093,7 @@ class PresenceHandler(BasePresenceHandler):
 
         await self._update_states([prev_state.copy_and_replace(**new_fields)])
 
-    async def is_visible(self, observed_user, observer_user):
+    async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool:
         """Returns whether a user can see another user's presence."""
         observer_room_ids = await self.store.get_rooms_for_user(
             observer_user.to_string()
@@ -1144,7 +1148,7 @@ class PresenceHandler(BasePresenceHandler):
         )
         return rows
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when new events have happened. Handles users and servers
         joining rooms and require being sent presence.
         """
@@ -1163,7 +1167,7 @@ class PresenceHandler(BasePresenceHandler):
 
         run_as_background_process("presence.notify_new_event", _process_presence)
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "presence_delta"):
@@ -1188,7 +1192,7 @@ class PresenceHandler(BasePresenceHandler):
                     max_pos
                 )
 
-    async def _handle_state_delta(self, deltas):
+    async def _handle_state_delta(self, deltas: List[JsonDict]) -> None:
         """Process current state deltas to find new joins that need to be
         handled.
         """
@@ -1311,7 +1315,7 @@ class PresenceHandler(BasePresenceHandler):
             return [remote_host], states
 
 
-def should_notify(old_state, new_state):
+def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool:
     """Decides if a presence state change should be sent to interested parties."""
     if old_state == new_state:
         return False
@@ -1347,7 +1351,9 @@ def should_notify(old_state, new_state):
     return False
 
 
-def format_user_presence_state(state, now, include_user_id=True):
+def format_user_presence_state(
+    state: UserPresenceState, now: int, include_user_id: bool = True
+) -> JsonDict:
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
 
@@ -1385,11 +1391,11 @@ class PresenceEventSource:
     @log_function
     async def get_new_events(
         self,
-        user,
-        from_key,
-        room_ids=None,
-        include_offline=True,
-        explicit_room_id=None,
+        user: UserID,
+        from_key: Optional[int],
+        room_ids: Optional[List[str]] = None,
+        include_offline: bool = True,
+        explicit_room_id: Optional[str] = None,
         **kwargs,
     ) -> Tuple[List[UserPresenceState], int]:
         # The process for getting presence events are:
@@ -1594,7 +1600,7 @@ class PresenceEventSource:
             if update.state != PresenceState.OFFLINE
         ]
 
-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.store.get_current_presence_token()
 
     @cached(num_args=2, cache_context=True)
@@ -1654,15 +1660,20 @@ class PresenceEventSource:
         return users_interested_in
 
 
-def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
+def handle_timeouts(
+    user_states: List[UserPresenceState],
+    is_mine_fn: Callable[[str], bool],
+    syncing_user_ids: Set[str],
+    now: int,
+) -> List[UserPresenceState]:
     """Checks the presence of users that have timed out and updates as
     appropriate.
 
     Args:
-        user_states(list): List of UserPresenceState's to check.
-        is_mine_fn (fn): Function that returns if a user_id is ours
-        syncing_user_ids (set): Set of user_ids with active syncs.
-        now (int): Current time in ms.
+        user_states: List of UserPresenceState's to check.
+        is_mine_fn: Function that returns if a user_id is ours
+        syncing_user_ids: Set of user_ids with active syncs.
+        now: Current time in ms.
 
     Returns:
         List of UserPresenceState updates
@@ -1679,14 +1690,16 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
     return list(changes.values())
 
 
-def handle_timeout(state, is_mine, syncing_user_ids, now):
+def handle_timeout(
+    state: UserPresenceState, is_mine: bool, syncing_user_ids: Set[str], now: int
+) -> Optional[UserPresenceState]:
     """Checks the presence of the user to see if any of the timers have elapsed
 
     Args:
-        state (UserPresenceState)
-        is_mine (bool): Whether the user is ours
-        syncing_user_ids (set): Set of user_ids with active syncs.
-        now (int): Current time in ms.
+        state
+        is_mine: Whether the user is ours
+        syncing_user_ids: Set of user_ids with active syncs.
+        now: Current time in ms.
 
     Returns:
         A UserPresenceState update or None if no update.
@@ -1738,23 +1751,29 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
     return state if changed else None
 
 
-def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
+def handle_update(
+    prev_state: UserPresenceState,
+    new_state: UserPresenceState,
+    is_mine: bool,
+    wheel_timer: WheelTimer,
+    now: int,
+) -> Tuple[UserPresenceState, bool, bool]:
     """Given a presence update:
         1. Add any appropriate timers.
         2. Check if we should notify anyone.
 
     Args:
-        prev_state (UserPresenceState)
-        new_state (UserPresenceState)
-        is_mine (bool): Whether the user is ours
-        wheel_timer (WheelTimer)
-        now (int): Time now in ms
+        prev_state
+        new_state
+        is_mine: Whether the user is ours
+        wheel_timer
+        now: Time now in ms
 
     Returns:
         3-tuple: `(new_state, persist_and_notify, federation_ping)` where:
             - new_state: is the state to actually persist
-            - persist_and_notify (bool): whether to persist and notify people
-            - federation_ping (bool): whether we should send a ping over federation
+            - persist_and_notify: whether to persist and notify people
+            - federation_ping: whether we should send a ping over federation
     """
     user_id = new_state.user_id
 
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 2c5bada1d8..20700fc5a8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1044,7 +1044,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
 
 class RoomMemberMasterHandler(RoomMemberHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.distributor = hs.get_distributor()
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 0eeb7c03f2..5414ce77d8 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 from twisted.web.client import PartialDownloadError
 
@@ -22,13 +22,16 @@ from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class UserInteractiveAuthChecker:
     """Abstract base class for an interactive auth checker"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         pass
 
     def is_enabled(self) -> bool:
@@ -57,10 +60,10 @@ class UserInteractiveAuthChecker:
 class DummyAuthChecker(UserInteractiveAuthChecker):
     AUTH_TYPE = LoginType.DUMMY
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return True
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return True
 
 
@@ -70,24 +73,24 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
     def is_enabled(self):
         return True
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return True
 
 
 class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     AUTH_TYPE = LoginType.RECAPTCHA
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._enabled = bool(hs.config.recaptcha_private_key)
         self._http_client = hs.get_proxied_http_client()
         self._url = hs.config.recaptcha_siteverify_api
         self._secret = hs.config.recaptcha_private_key
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return self._enabled
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         try:
             user_response = authdict["response"]
         except KeyError:
@@ -132,11 +135,11 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
 
 
 class _BaseThreepidAuthChecker:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
 
-    async def _check_threepid(self, medium, authdict):
+    async def _check_threepid(self, medium: str, authdict: dict) -> dict:
         if "threepid_creds" not in authdict:
             raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
 
@@ -206,31 +209,31 @@ class _BaseThreepidAuthChecker:
 class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     AUTH_TYPE = LoginType.EMAIL_IDENTITY
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         UserInteractiveAuthChecker.__init__(self, hs)
         _BaseThreepidAuthChecker.__init__(self, hs)
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return self.hs.config.threepid_behaviour_email in (
             ThreepidBehaviour.REMOTE,
             ThreepidBehaviour.LOCAL,
         )
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)
 
 
 class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
     AUTH_TYPE = LoginType.MSISDN
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         UserInteractiveAuthChecker.__init__(self, hs)
         _BaseThreepidAuthChecker.__init__(self, hs)
 
-    def is_enabled(self):
+    def is_enabled(self) -> bool:
         return bool(self.hs.config.account_threepid_delegate_msisdn)
 
-    async def check_auth(self, authdict, clientip):
+    async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("msisdn", authdict)
 
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index d0cf121743..f289ffe3d0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -37,9 +37,11 @@ from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
+    from synapse.api.auth import Auth
+    from synapse.handlers.pagination import PaginationHandler
+    from synapse.handlers.room import RoomShutdownHandler
     from synapse.server import HomeServer
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -146,50 +148,14 @@ class DeleteRoomRestServlet(RestServlet):
     async def on_POST(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        await assert_user_is_admin(self.auth, requester.user)
-
-        content = parse_json_object_from_request(request)
-
-        block = content.get("block", False)
-        if not isinstance(block, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'block' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
-
-        purge = content.get("purge", True)
-        if not isinstance(purge, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'purge' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
-
-        force_purge = content.get("force_purge", False)
-        if not isinstance(force_purge, bool):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'force_purge' must be a boolean, if given",
-                Codes.BAD_JSON,
-            )
-
-        ret = await self.room_shutdown_handler.shutdown_room(
-            room_id=room_id,
-            new_room_user_id=content.get("new_room_user_id"),
-            new_room_name=content.get("room_name"),
-            message=content.get("message"),
-            requester_user_id=requester.user.to_string(),
-            block=block,
+        return await _delete_room(
+            request,
+            room_id,
+            self.auth,
+            self.room_shutdown_handler,
+            self.pagination_handler,
         )
 
-        # Purge room
-        if purge:
-            await self.pagination_handler.purge_room(room_id, force=force_purge)
-
-        return (200, ret)
-
 
 class ListRoomRestServlet(RestServlet):
     """
@@ -282,7 +248,22 @@ class ListRoomRestServlet(RestServlet):
 
 
 class RoomRestServlet(RestServlet):
-    """Get room details.
+    """Manage a room.
+
+    On GET : Get details of a room.
+
+    On DELETE : Delete a room from server.
+
+    It is a combination and improvement of shutdown and purge room.
+
+    Shuts down a room by removing all local users from the room.
+    Blocking all future invites and joins to the room is optional.
+
+    If desired any local aliases will be repointed to a new room
+    created by `new_room_user_id` and kicked users will be auto-
+    joined to the new room.
+
+    If 'purge' is true, it will remove all traces of a room from the database.
 
     TODO: Add on_POST to allow room creation without joining the room
     """
@@ -293,6 +274,8 @@ class RoomRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
+        self.room_shutdown_handler = hs.get_room_shutdown_handler()
+        self.pagination_handler = hs.get_pagination_handler()
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -308,6 +291,17 @@ class RoomRestServlet(RestServlet):
 
         return (200, ret)
 
+    async def on_DELETE(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        return await _delete_room(
+            request,
+            room_id,
+            self.auth,
+            self.room_shutdown_handler,
+            self.pagination_handler,
+        )
+
 
 class RoomMembersRestServlet(RestServlet):
     """
@@ -694,3 +688,55 @@ class RoomEventContextServlet(RestServlet):
         )
 
         return 200, results
+
+
+async def _delete_room(
+    request: SynapseRequest,
+    room_id: str,
+    auth: "Auth",
+    room_shutdown_handler: "RoomShutdownHandler",
+    pagination_handler: "PaginationHandler",
+) -> Tuple[int, JsonDict]:
+    requester = await auth.get_user_by_req(request)
+    await assert_user_is_admin(auth, requester.user)
+
+    content = parse_json_object_from_request(request)
+
+    block = content.get("block", False)
+    if not isinstance(block, bool):
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST,
+            "Param 'block' must be a boolean, if given",
+            Codes.BAD_JSON,
+        )
+
+    purge = content.get("purge", True)
+    if not isinstance(purge, bool):
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST,
+            "Param 'purge' must be a boolean, if given",
+            Codes.BAD_JSON,
+        )
+
+    force_purge = content.get("force_purge", False)
+    if not isinstance(force_purge, bool):
+        raise SynapseError(
+            HTTPStatus.BAD_REQUEST,
+            "Param 'force_purge' must be a boolean, if given",
+            Codes.BAD_JSON,
+        )
+
+    ret = await room_shutdown_handler.shutdown_room(
+        room_id=room_id,
+        new_room_user_id=content.get("new_room_user_id"),
+        new_room_name=content.get("room_name"),
+        message=content.get("message"),
+        requester_user_id=requester.user.to_string(),
+        block=block,
+    )
+
+    # Purge room
+    if purge:
+        await pagination_handler.purge_room(room_id, force=force_purge)
+
+    return (200, ret)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index bd39c095af..a761ad603b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -715,7 +715,9 @@ class DatabasePool:
             # pool).
             assert not self.engine.in_transaction(conn)
 
-            with LoggingContext("runWithConnection", parent_context) as context:
+            with LoggingContext(
+                str(curr_context), parent_context=parent_context
+            ) as context:
                 sched_duration_sec = monotonic_time() - start_time
                 sql_scheduling_timer.observe(sched_duration_sec)
                 context.add_database_scheduled(sched_duration_sec)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index a21d34fcb4..10b0ec6b75 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -17,8 +17,10 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Collection,
     Generic,
     Iterable,
+    List,
     Optional,
     Type,
     TypeVar,
@@ -57,13 +59,56 @@ class _Node:
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks"]
 
     def __init__(
-        self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+        self,
+        prev_node,
+        next_node,
+        key,
+        value,
+        callbacks: Collection[Callable[[], None]] = (),
     ):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
-        self.callbacks = callbacks or set()
+
+        # Set of callbacks to run when the node gets deleted. We store as a list
+        # rather than a set to keep memory usage down (and since we expect few
+        # entries per node, the performance of checking for duplication in a
+        # list vs using a set is negligible).
+        #
+        # Note that we store this as an optional list to keep the memory
+        # footprint down. Storing `None` is free as its a singleton, while empty
+        # lists are 56 bytes (and empty sets are 216 bytes, if we did the naive
+        # thing and used sets).
+        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+
+        self.add_callbacks(callbacks)
+
+    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+        """Add to stored list of callbacks, removing duplicates."""
+
+        if not callbacks:
+            return
+
+        if not self.callbacks:
+            self.callbacks = []
+
+        for callback in callbacks:
+            if callback not in self.callbacks:
+                self.callbacks.append(callback)
+
+    def run_and_clear_callbacks(self) -> None:
+        """Run all callbacks and clear the stored list of callbacks. Used when
+        the node is being deleted.
+        """
+
+        if not self.callbacks:
+            return
+
+        for callback in self.callbacks:
+            callback()
+
+        self.callbacks = None
 
 
 class LruCache(Generic[KT, VT]):
@@ -177,10 +222,10 @@ class LruCache(Generic[KT, VT]):
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Optional[set] = None):
+        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks or set())
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -211,16 +256,15 @@ class LruCache(Generic[KT, VT]):
                 deleted_len = size_callback(node.value)
                 cached_cache_len[0] -= deleted_len
 
-            for cb in node.callbacks:
-                cb()
-            node.callbacks.clear()
+            node.run_and_clear_callbacks()
+
             return deleted_len
 
         @overload
         def cache_get(
             key: KT,
             default: Literal[None] = None,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Optional[VT]:
             ...
@@ -229,7 +273,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: T,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Union[T, VT]:
             ...
@@ -238,13 +282,13 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Optional[T] = None,
-            callbacks: Iterable[Callable[[], None]] = (),
+            callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
         ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
                 return node.value
@@ -260,10 +304,8 @@ class LruCache(Generic[KT, VT]):
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
                 # the check if we have some callbacks to call.
-                if node.callbacks and value != node.value:
-                    for cb in node.callbacks:
-                        cb()
-                    node.callbacks.clear()
+                if value != node.value:
+                    node.run_and_clear_callbacks()
 
                 # We don't bother to protect this by value != node.value as
                 # generally size_callback will be cheap compared with equality
@@ -273,7 +315,7 @@ class LruCache(Generic[KT, VT]):
                     cached_cache_len[0] -= size_callback(node.value)
                     cached_cache_len[0] += size_callback(value)
 
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
 
                 move_node_to_front(node)
                 node.value = value
@@ -326,8 +368,7 @@ class LruCache(Generic[KT, VT]):
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
-                for cb in node.callbacks:
-                    cb()
+                node.run_and_clear_callbacks()
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0