summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-12-29 17:42:10 -0500
committerGitHub <noreply@github.com>2020-12-29 17:42:10 -0500
commit9999eb2d0270519f397343c90dfb394823d08e82 (patch)
treeb737049dc33bbd76da7d3ba70b8f817a740c611f /synapse
parentValidate input parameters for the sendToDevice API. (#8975) (diff)
downloadsynapse-9999eb2d0270519f397343c90dfb394823d08e82.tar.xz
Add type hints to admin and room list handlers. (#8973)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/admin.py63
-rw-r--r--synapse/handlers/room_list.py94
-rw-r--r--synapse/storage/databases/main/client_ips.py7
3 files changed, 94 insertions, 70 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index a703944543..37e63da9b1 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -13,27 +13,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
-from typing import List
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 from synapse.api.constants import Membership
-from synapse.events import FrozenEvent
-from synapse.types import RoomStreamToken, StateMap
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AdminHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    async def get_whois(self, user):
+    async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
 
         sessions = await self.store.get_user_ip_and_agents(user)
@@ -53,7 +57,7 @@ class AdminHandler(BaseHandler):
 
         return ret
 
-    async def get_user(self, user):
+    async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
         ret = await self.store.get_user_by_id(user.to_string())
         if ret:
@@ -64,12 +68,12 @@ class AdminHandler(BaseHandler):
             ret["threepids"] = threepids
         return ret
 
-    async def export_user_data(self, user_id, writer):
+    async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
         """Write all data we have on the user to the given writer.
 
         Args:
-            user_id (str)
-            writer (ExfiltrationWriter)
+            user_id: The user ID to fetch data of.
+            writer: The writer to write to.
 
         Returns:
             Resolves when all data for a user has been written.
@@ -128,7 +132,8 @@ class AdminHandler(BaseHandler):
             from_key = RoomStreamToken(0, 0)
             to_key = RoomStreamToken(None, stream_ordering)
 
-            written_events = set()  # Events that we've processed in this room
+            # Events that we've processed in this room
+            written_events = set()  # type: Set[str]
 
             # We need to track gaps in the events stream so that we can then
             # write out the state at those events. We do this by keeping track
@@ -140,8 +145,8 @@ class AdminHandler(BaseHandler):
 
             # The reverse mapping to above, i.e. map from unseen event to events
             # that have the unseen event in their prev_events, i.e. the unseen
-            # events "children". dict[str, set[str]]
-            unseen_to_child_events = {}
+            # events "children".
+            unseen_to_child_events = {}  # type: Dict[str, Set[str]]
 
             # We fetch events in the room the user could see by fetching *all*
             # events that we have and then filtering, this isn't the most
@@ -197,38 +202,46 @@ class AdminHandler(BaseHandler):
         return writer.finished()
 
 
-class ExfiltrationWriter:
+class ExfiltrationWriter(metaclass=abc.ABCMeta):
     """Interface used to specify how to write exported data.
     """
 
-    def write_events(self, room_id: str, events: List[FrozenEvent]):
+    @abc.abstractmethod
+    def write_events(self, room_id: str, events: List[EventBase]) -> None:
         """Write a batch of events for a room.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+    @abc.abstractmethod
+    def write_state(
+        self, room_id: str, event_id: str, state: StateMap[EventBase]
+    ) -> None:
         """Write the state at the given event in the room.
 
         This only gets called for backward extremities rather than for each
         event.
         """
-        pass
+        raise NotImplementedError()
 
-    def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+    @abc.abstractmethod
+    def write_invite(
+        self, room_id: str, event: EventBase, state: StateMap[dict]
+    ) -> None:
         """Write an invite for the room, with associated invite state.
 
         Args:
-            room_id
-            event
-            state: A subset of the state at the
-                invite, with a subset of the event keys (type, state_key
-                content and sender)
+            room_id: The room ID the invite is for.
+            event: The invite event.
+            state: A subset of the state at the invite, with a subset of the
+                event keys (type, state_key content and sender).
         """
+        raise NotImplementedError()
 
-    def finished(self):
+    @abc.abstractmethod
+    def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
         This functions return value is passed to the caller of
         `export_user_data`.
         """
-        pass
+        raise NotImplementedError()
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index bf58d302b0..14f14db449 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,19 +15,22 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
 from synapse.api.errors import Codes, HttpResponseException
-from synapse.types import ThirdPartyInstanceID
+from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.response_cache import ResponseCache
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
@@ -37,37 +40,38 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
 
 
 class RoomListHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search
-        self.response_cache = ResponseCache(hs, "room_list")
+        self.response_cache = ResponseCache(
+            hs, "room_list"
+        )  # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
         self.remote_response_cache = ResponseCache(
             hs, "remote_room_list", timeout_ms=30 * 1000
-        )
+        )  # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
 
     async def get_local_public_room_list(
         self,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        network_tuple=EMPTY_THIRD_PARTY_ID,
-        from_federation=False,
-    ):
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+        from_federation: bool = False,
+    ) -> JsonDict:
         """Generate a local public room list.
 
         There are multiple different lists: the main one plus one per third
         party network. A client can ask for a specific list or to return all.
 
         Args:
-            limit (int|None)
-            since_token (str|None)
-            search_filter (dict|None)
-            network_tuple (ThirdPartyInstanceID): Which public list to use.
+            limit
+            since_token
+            search_filter
+            network_tuple: Which public list to use.
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation (bool): true iff the request comes from the federation
-                API
+            from_federation: true iff the request comes from the federation API
         """
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -107,10 +111,10 @@ class RoomListHandler(BaseHandler):
         self,
         limit: Optional[int] = None,
         since_token: Optional[str] = None,
-        search_filter: Optional[Dict] = None,
+        search_filter: Optional[dict] = None,
         network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
         from_federation: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Generate a public room list.
         Args:
             limit: Maximum amount of rooms to return.
@@ -131,13 +135,17 @@ class RoomListHandler(BaseHandler):
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
 
-            bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+            bounds = (
+                batch_token.last_joined_members,
+                batch_token.last_room_id,
+            )  # type: Optional[Tuple[int, str]]
             forwards = batch_token.direction_is_forward
+            has_batch_token = True
         else:
-            batch_token = None
             bounds = None
 
             forwards = True
+            has_batch_token = False
 
         # we request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
@@ -169,7 +177,7 @@ class RoomListHandler(BaseHandler):
 
         results = [build_room_entry(r) for r in results]
 
-        response = {}
+        response = {}  # type: JsonDict
         num_results = len(results)
         if limit is not None:
             more_to_come = num_results == probing_limit
@@ -187,7 +195,7 @@ class RoomListHandler(BaseHandler):
             initial_entry = results[0]
 
             if forwards:
-                if batch_token:
+                if has_batch_token:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
@@ -203,7 +211,7 @@ class RoomListHandler(BaseHandler):
                         direction_is_forward=True,
                     ).to_token()
             else:
-                if batch_token:
+                if has_batch_token:
                     response["next_batch"] = RoomListNextBatch(
                         last_joined_members=final_entry["num_joined_members"],
                         last_room_id=final_entry["room_id"],
@@ -293,7 +301,7 @@ class RoomListHandler(BaseHandler):
                 return None
 
         # Return whether this room is open to federation users or not
-        create_event = current_state.get((EventTypes.Create, ""))
+        create_event = current_state[EventTypes.Create, ""]
         result["m.federate"] = create_event.content.get("m.federate", True)
 
         name_event = current_state.get((EventTypes.Name, ""))
@@ -336,13 +344,13 @@ class RoomListHandler(BaseHandler):
 
     async def get_remote_public_room_list(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
 
@@ -399,13 +407,13 @@ class RoomListHandler(BaseHandler):
 
     async def _get_remote_list_cached(
         self,
-        server_name,
-        limit=None,
-        since_token=None,
-        search_filter=None,
-        include_all_networks=False,
-        third_party_instance_id=None,
-    ):
+        server_name: str,
+        limit: Optional[int] = None,
+        since_token: Optional[str] = None,
+        search_filter: Optional[dict] = None,
+        include_all_networks: bool = False,
+        third_party_instance_id: Optional[str] = None,
+    ) -> JsonDict:
         repl_layer = self.hs.get_federation_client()
         if search_filter:
             # We can't cache when asking for search
@@ -456,24 +464,24 @@ class RoomListNextBatch(
     REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
 
     @classmethod
-    def from_token(cls, token):
+    def from_token(cls, token: str) -> "RoomListNextBatch":
         decoded = msgpack.loads(decode_base64(token), raw=False)
         return RoomListNextBatch(
             **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
         )
 
-    def to_token(self):
+    def to_token(self) -> str:
         return encode_base64(
             msgpack.dumps(
                 {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
             )
         )
 
-    def copy_and_replace(self, **kwds):
+    def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
         return self._replace(**kwds)
 
 
-def _matches_room_entry(room_entry, search_filter):
+def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
     if search_filter and search_filter.get("generic_search_term", None):
         generic_search_term = search_filter["generic_search_term"].upper()
         if generic_search_term in room_entry.get("name", "").upper():
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 339bd691a4..e96a8b3f43 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,11 +14,12 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.types import UserID
 from synapse.util.caches.lrucache import LruCache
 
 logger = logging.getLogger(__name__)
@@ -546,7 +547,9 @@ class ClientIpStore(ClientIpWorkerStore):
                     }
         return ret
 
-    async def get_user_ip_and_agents(self, user):
+    async def get_user_ip_and_agents(
+        self, user: UserID
+    ) -> List[Dict[str, Union[str, int]]]:
         user_id = user.to_string()
         results = {}