summary refs log tree commit diff
path: root/synapse/handlers/pagination.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/pagination.py56
1 files changed, 32 insertions, 24 deletions
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5a1aa7d830..63d7edff87 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set
 
 from twisted.python.failure import Failure
 
@@ -30,6 +30,10 @@ from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +72,7 @@ class PaginationHandler(object):
     paginating during a purge.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -78,9 +82,9 @@ class PaginationHandler(object):
         self._server_name = hs.hostname
 
         self.pagination_lock = ReadWriteLock()
-        self._purges_in_progress_by_room = set()
+        self._purges_in_progress_by_room = set()  # type: Set[str]
         # map from purge id to PurgeStatus
-        self._purges_by_id = {}
+        self._purges_by_id = {}  # type: Dict[str, PurgeStatus]
         self._event_serializer = hs.get_event_client_serializer()
 
         self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
@@ -102,7 +106,9 @@ class PaginationHandler(object):
                     job["longest_max_lifetime"],
                 )
 
-    async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+    async def purge_history_for_rooms_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int]
+    ):
         """Purge outdated events from rooms within the given retention range.
 
         If a default retention policy is defined in the server's configuration and its
@@ -110,10 +116,10 @@ class PaginationHandler(object):
         retention policy.
 
         Args:
-            min_ms (int|None): Duration in milliseconds that define the lower limit of
+            min_ms: Duration in milliseconds that define the lower limit of
                 the range to handle (exclusive). If None, it means that the range has no
                 lower limit.
-            max_ms (int|None): Duration in milliseconds that define the upper limit of
+            max_ms: Duration in milliseconds that define the upper limit of
                 the range to handle (inclusive). If None, it means that the range has no
                 upper limit.
         """
@@ -220,18 +226,19 @@ class PaginationHandler(object):
                 "_purge_history", self._purge_history, purge_id, room_id, token, True,
             )
 
-    def start_purge_history(self, room_id, token, delete_local_events=False):
+    def start_purge_history(
+        self, room_id: str, token: str, delete_local_events: bool = False
+    ) -> str:
         """Start off a history purge on a room.
 
         Args:
-            room_id (str): The room to purge from
-
-            token (str): topological token to delete events before
-            delete_local_events (bool): True to delete local events as well as
+            room_id: The room to purge from
+            token: topological token to delete events before
+            delete_local_events: True to delete local events as well as
                 remote ones
 
         Returns:
-            str: unique ID for this purge transaction.
+            unique ID for this purge transaction.
         """
         if room_id in self._purges_in_progress_by_room:
             raise SynapseError(
@@ -284,14 +291,11 @@ class PaginationHandler(object):
 
             self.hs.get_reactor().callLater(24 * 3600, clear_purge)
 
-    def get_purge_status(self, purge_id):
+    def get_purge_status(self, purge_id: str) -> Optional[PurgeStatus]:
         """Get the current status of an active purge
 
         Args:
-            purge_id (str): purge_id returned by start_purge_history
-
-        Returns:
-            PurgeStatus|None
+            purge_id: purge_id returned by start_purge_history
         """
         return self._purges_by_id.get(purge_id)
 
@@ -312,8 +316,8 @@ class PaginationHandler(object):
     async def get_messages(
         self,
         requester: Requester,
-        room_id: Optional[str] = None,
-        pagin_config: Optional[PaginationConfig] = None,
+        room_id: str,
+        pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
     ) -> Dict[str, Any]:
@@ -368,11 +372,15 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
+
+                    # This is only None if the room is world_readable, in which
+                    # case "JOIN" would have been returned.
+                    assert member_event_id
+
                     leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    leave_token = RoomStreamToken.parse(leave_token)
-                    if leave_token.topological < max_topo:
+                    if RoomStreamToken.parse(leave_token).topological < max_topo:
                         source_config.from_key = str(leave_token)
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
@@ -419,8 +427,8 @@ class PaginationHandler(object):
             )
 
             if state_ids:
-                state = await self.store.get_events(list(state_ids.values()))
-                state = state.values()
+                state_dict = await self.store.get_events(list(state_ids.values()))
+                state = state_dict.values()
 
         time_now = self.clock.time_msec()