summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-08-23 08:14:42 -0400
committerGitHub <noreply@github.com>2021-08-23 08:14:42 -0400
commitbd7d398b05aaa18d5b0629153ababeea7539256c (patch)
treeca108d070b09531af89ade9b64ddab3a1b0d2c10 /synapse
parentAddtional type hints for the REST servlets. (#10665) (diff)
downloadsynapse-bd7d398b05aaa18d5b0629153ababeea7539256c.tar.xz
Additional type hints for the sync REST servlet. (#10666)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py21
-rw-r--r--synapse/rest/client/sync.py132
2 files changed, 92 insertions, 61 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2203c45dcc..86c3c7f0df 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -30,6 +30,7 @@ from prometheus_client import Counter
 
 from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
+from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging.context import current_context
@@ -231,7 +232,7 @@ class SyncResult:
     """
 
     next_batch: StreamToken
-    presence: List[JsonDict]
+    presence: List[UserPresenceState]
     account_data: List[JsonDict]
     joined: List[JoinedSyncResult]
     invited: List[InvitedSyncResult]
@@ -2177,14 +2178,14 @@ class SyncResultBuilder:
         joined_room_ids: List of rooms the user is joined to
 
         # The following mirror the fields in a sync response
-        presence (list)
-        account_data (list)
-        joined (list[JoinedSyncResult])
-        invited (list[InvitedSyncResult])
-        knocked (list[KnockedSyncResult])
-        archived (list[ArchivedSyncResult])
-        groups (GroupsSyncResult|None)
-        to_device (list)
+        presence
+        account_data
+        joined
+        invited
+        knocked
+        archived
+        groups
+        to_device
     """
 
     sync_config: SyncConfig
@@ -2193,7 +2194,7 @@ class SyncResultBuilder:
     now_token: StreamToken
     joined_room_ids: FrozenSet[str]
 
-    presence: List[JsonDict] = attr.Factory(list)
+    presence: List[UserPresenceState] = attr.Factory(list)
     account_data: List[JsonDict] = attr.Factory(list)
     joined: List[JoinedSyncResult] = attr.Factory(list)
     invited: List[InvitedSyncResult] = attr.Factory(list)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e18f4d01b3..65c37be3e9 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,17 +14,26 @@
 import itertools
 import logging
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.api.presence import UserPresenceState
 from synapse.events.utils import (
     format_event_for_client_v2_without_room_id,
     format_event_raw,
 )
 from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
+from synapse.handlers.sync import (
+    ArchivedSyncResult,
+    InvitedSyncResult,
+    JoinedSyncResult,
+    KnockedSyncResult,
+    SyncConfig,
+    SyncResult,
+)
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, StreamToken
@@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet):
             return 200, {}
 
         time_now = self.clock.time_msec()
+        # We know that the the requester has an access token since appservices
+        # cannot use sync.
         response_content = await self.encode_response(
             time_now, sync_result, requester.access_token_id, filter_collection
         )
@@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    async def encode_response(self, time_now, sync_result, access_token_id, filter):
+    async def encode_response(
+        self,
+        time_now: int,
+        sync_result: SyncResult,
+        access_token_id: Optional[int],
+        filter: FilterCollection,
+    ) -> JsonDict:
         logger.debug("Formatting events in sync response")
         if filter.event_format == "client":
             event_formatter = format_event_for_client_v2_without_room_id
@@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet):
 
         logger.debug("building sync response dict")
 
-        response: dict = defaultdict(dict)
+        response: JsonDict = defaultdict(dict)
         response["next_batch"] = await sync_result.next_batch.to_string(self.store)
 
         if sync_result.account_data:
@@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet):
         if archived:
             response["rooms"][Membership.LEAVE] = archived
 
+        # By the time we get here groups is no longer optional.
+        assert sync_result.groups is not None
         if sync_result.groups.join:
             response["groups"][Membership.JOIN] = sync_result.groups.join
         if sync_result.groups.invite:
@@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet):
         return response
 
     @staticmethod
-    def encode_presence(events, time_now):
+    def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
         return {
             "events": [
                 {
@@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet):
         }
 
     async def encode_joined(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
+        self,
+        rooms: List[JoinedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_fields: List[str],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the joined rooms in a sync result
 
         Args:
-            rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
-                results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is joined to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_fields(list<str>): List of event fields to include. If empty,
+            event_fields: List of event fields to include. If empty,
                 all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
+            event_formatter: function to convert from federation format
                 to client format
         Returns:
-            dict[str, dict[str, object]]: the joined rooms list, in our
-                response format
+            The joined rooms list, in our response format
         """
         joined = {}
         for room in rooms:
@@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    async def encode_invited(self, rooms, time_now, token_id, event_formatter):
+    async def encode_invited(
+        self,
+        rooms: List[InvitedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the invited rooms in a sync result
 
         Args:
-            rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
-                sync results for rooms this user is invited to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is invited to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_formatter (func[dict]): function to convert from federation format
+            event_formatter: function to convert from federation format
                 to client format
 
         Returns:
-            dict[str, dict[str, object]]: the invited rooms list, in our
-                response format
+            The invited rooms list, in our response format
         """
         invited = {}
         for room in rooms:
@@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[KnockedSyncResult],
         time_now: int,
-        token_id: int,
+        token_id: Optional[int],
         event_formatter: Callable[[Dict], Dict],
     ) -> Dict[str, Dict[str, Any]]:
         """
@@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet):
         return knocked
 
     async def encode_archived(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
+        self,
+        rooms: List[ArchivedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        event_fields: List[str],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Encode the archived rooms in a sync result
 
         Args:
-            rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
-                sync results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
+            rooms: list of sync results for rooms this user is joined to
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            event_fields(list<str>): List of event fields to include. If empty,
+            event_fields: List of event fields to include. If empty,
                 all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
+            event_formatter: function to convert from federation format to client format
         Returns:
-            dict[str, dict[str, object]]: The invited rooms list, in our
-                response format
+            The archived rooms list, in our response format
         """
         joined = {}
         for room in rooms:
@@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet):
         return joined
 
     async def encode_room(
-        self, room, time_now, token_id, joined, only_fields, event_formatter
-    ):
+        self,
+        room: Union[JoinedSyncResult, ArchivedSyncResult],
+        time_now: int,
+        token_id: Optional[int],
+        joined: bool,
+        only_fields: Optional[List[str]],
+        event_formatter: Callable[[JsonDict], JsonDict],
+    ) -> JsonDict:
         """
         Args:
-            room (JoinedSyncResult|ArchivedSyncResult): sync result for a
-                single room
-            time_now (int): current time - used as a baseline for age
-                calculations
-            token_id (int): ID of the user's auth token - used for namespacing
+            room: sync result for a single room
+            time_now: current time - used as a baseline for age calculations
+            token_id: ID of the user's auth token - used for namespacing
                 of transaction IDs
-            joined (bool): True if the user is joined to this room - will mean
+            joined: True if the user is joined to this room - will mean
                 we handle ephemeral events
-            only_fields(list<str>): Optional. The list of event fields to include.
-            event_formatter (func[dict]): function to convert from federation format
+            only_fields: Optional. The list of event fields to include.
+            event_formatter: function to convert from federation format
                 to client format
         Returns:
-            dict[str, object]: the room, encoded in our response format
+            The room, encoded in our response format
         """
 
         def serialize(events):
@@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet):
 
         account_data = room.account_data
 
-        result = {
+        result: JsonDict = {
             "timeline": {
                 "events": serialized_timeline,
                 "prev_batch": await room.timeline.prev_batch.to_string(self.store),
@@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet):
         }
 
         if joined:
+            assert isinstance(room, JoinedSyncResult)
             ephemeral_events = room.ephemeral
             result["ephemeral"] = {"events": ephemeral_events}
             result["unread_notifications"] = room.unread_notifications
@@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet):
         return result
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SyncRestServlet(hs).register(http_server)