summary refs log tree commit diff
path: root/synapse/handlers/admin.py
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
committerBen Banfield-Zanin <benbz@matrix.org>2021-02-16 13:33:20 +0000
commitdcf1b9c276e22bb6f5200fc029301c4d40e87a1f (patch)
tree1f5badce24645d99534133a7a989069906088fff /synapse/handlers/admin.py
parentMerge remote-tracking branch 'origin/release-v1.24.0' into bbz/info-mainline-... (diff)
parentFixup CHANGES (diff)
downloadsynapse-dcf1b9c276e22bb6f5200fc029301c4d40e87a1f.tar.xz
Merge remote-tracking branch 'origin/release-v1.27.0' into bbz/info-mainline-1.27.0 github/bbz/info-mainline-1.27.0 bbz/info-mainline-1.27.0
Diffstat (limited to 'synapse/handlers/admin.py')
-rw-r--r--synapse/handlers/admin.py63
1 files changed, 38 insertions, 25 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AdminHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    async def get_whois(self, user):
+    async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
 
         sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
 
         return ret
 
-    async def get_user(self, user):
+    async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
         ret = await self.store.get_user_by_id(user.to_string())
         if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
             ret["threepids"] = threepids
         return ret
 
-    async def export_user_data(self, user_id, writer):
+    async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
         """Write all data we have on the user to the given writer.
 
         Args:
-            user_id (str)
-            writer (ExfiltrationWriter)
+            user_id: The user ID to fetch data of.
+            writer: The writer to write to.
 
         Returns:
             Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
             from_key = RoomStreamToken(0, 0)
             to_key = RoomStreamToken(None, stream_ordering)
 
-            written_events = set()  # Events that we've processed in this room
+            # Events that we've processed in this room
+            written_events = set()  # type: Set[str]
 
             # We need to track gaps in the events stream so that we can then
             # write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
 
             # The reverse mapping to above, i.e. map from unseen event to events
             # that have the unseen event in their prev_events, i.e. the unseen
-            # events "children". dict[str, set[str]]
-            unseen_to_child_events = {}
+            # events "children".
+            unseen_to_child_events = {}  # type: Dict[str, Set[str]]
 
             # We fetch events in the room the user could see by fetching *all*
             # events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
         return writer.finished()
 
 
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id: str, events: List[FrozenEvent]):
+    @abc.abstractmethod
+    def write_events(self, room_id: str, events: List[EventBase]) -> None:
         """Write a batch of events for a room.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+    @abc.abstractmethod
+    def write_state(
+        self, room_id: str, event_id: str, state: StateMap[EventBase]
+    ) -> None:
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+    @abc.abstractmethod
+    def write_invite(
+        self, room_id: str, event: EventBase, state: StateMap[dict]
+    ) -> None:
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id
-            event
-            state: A subset of the state at the
-                invite, with a subset of the event keys (type, state_key
-                content and sender)
+            room_id: The room ID the invite is for.
+            event: The invite event.
+            state: A subset of the state at the invite, with a subset of the
+                event keys (type, state_key content and sender).
         """
+        raise NotImplementedError()
 
-    def finished(self):
+    @abc.abstractmethod
+    def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
         This functions return value is passed to the caller of
         `export_user_data`.
         """
-        pass
+        raise NotImplementedError()