summary refs log tree commit diff
path: root/synapse/handlers/initial_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/initial_sync.py')
-rw-r--r--synapse/handlers/initial_sync.py80
1 files changed, 48 insertions, 32 deletions
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index ae6bd1d352..d5ddc583ad 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from twisted.internet import defer
 
@@ -22,8 +23,9 @@ from synapse.api.errors import SynapseError
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import StreamToken, UserID
+from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -31,11 +33,15 @@ from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class InitialSyncHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super(InitialSyncHandler, self).__init__(hs)
         self.hs = hs
         self.state = hs.get_state_handler()
@@ -48,27 +54,25 @@ class InitialSyncHandler(BaseHandler):
 
     def snapshot_all_rooms(
         self,
-        user_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        include_archived=False,
-    ):
+        user_id: str,
+        pagin_config: PaginationConfig,
+        as_client_event: bool = True,
+        include_archived: bool = False,
+    ) -> JsonDict:
         """Retrieve a snapshot of all rooms the user is invited or has joined.
 
         This snapshot may include messages for all rooms where the user is
         joined, depending on the pagination config.
 
         Args:
-            user_id (str): The ID of the user making the request.
-            pagin_config (synapse.api.streams.PaginationConfig): The pagination
-            config used to determine how many messages *PER ROOM* to return.
-            as_client_event (bool): True to get events in client-server format.
-            include_archived (bool): True to get rooms that the user has left
+            user_id: The ID of the user making the request.
+            pagin_config: The pagination config used to determine how many
+                messages *PER ROOM* to return.
+            as_client_event: True to get events in client-server format.
+            include_archived: True to get rooms that the user has left
         Returns:
-            A list of dicts with "room_id" and "membership" keys for all rooms
-            the user is currently invited or joined in on. Rooms where the user
-            is joined on, may return a "messages" key with messages, depending
-            on the specified PaginationConfig.
+            A JsonDict with the same format as the response to `/intialSync`
+            API
         """
         key = (
             user_id,
@@ -91,11 +95,11 @@ class InitialSyncHandler(BaseHandler):
 
     async def _snapshot_all_rooms(
         self,
-        user_id=None,
-        pagin_config=None,
-        as_client_event=True,
-        include_archived=False,
-    ):
+        user_id: str,
+        pagin_config: PaginationConfig,
+        as_client_event: bool = True,
+        include_archived: bool = False,
+    ) -> JsonDict:
 
         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:
@@ -134,7 +138,7 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        async def handle_room(event):
+        async def handle_room(event: RoomsForUser):
             d = {
                 "room_id": event.room_id,
                 "membership": event.membership,
@@ -251,17 +255,18 @@ class InitialSyncHandler(BaseHandler):
 
         return ret
 
-    async def room_initial_sync(self, requester, room_id, pagin_config=None):
+    async def room_initial_sync(
+        self, requester: Requester, room_id: str, pagin_config: PaginationConfig
+    ) -> JsonDict:
         """Capture the a snapshot of a room. If user is currently a member of
         the room this will be what is currently in the room. If the user left
         the room this will be what was in the room when they left.
 
         Args:
-            requester(Requester): The user to get a snapshot for.
-            room_id(str): The room to get a snapshot of.
-            pagin_config(synapse.streams.config.PaginationConfig):
-                The pagination config used to determine how many messages to
-                return.
+            requester: The user to get a snapshot for.
+            room_id: The room to get a snapshot of.
+            pagin_config: The pagination config used to determine how many
+                messages to return.
         Raises:
             AuthError if the user wasn't in the room.
         Returns:
@@ -305,8 +310,14 @@ class InitialSyncHandler(BaseHandler):
         return result
 
     async def _room_initial_sync_parted(
-        self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
-    ):
+        self,
+        user_id: str,
+        room_id: str,
+        pagin_config: PaginationConfig,
+        membership: Membership,
+        member_event_id: str,
+        is_peeking: bool,
+    ) -> JsonDict:
         room_state = await self.state_store.get_state_for_events([member_event_id])
 
         room_state = room_state[member_event_id]
@@ -350,8 +361,13 @@ class InitialSyncHandler(BaseHandler):
         }
 
     async def _room_initial_sync_joined(
-        self, user_id, room_id, pagin_config, membership, is_peeking
-    ):
+        self,
+        user_id: str,
+        room_id: str,
+        pagin_config: PaginationConfig,
+        membership: Membership,
+        is_peeking: bool,
+    ) -> JsonDict:
         current_state = await self.state.get_current_state(room_id=room_id)
 
         # TODO: These concurrently