summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-08 16:48:15 +0100
committerGitHub <noreply@github.com>2020-09-08 16:48:15 +0100
commit63c0e9e1954fc7fc10a2575c54aecc8944de60f3 (patch)
tree29b8c5045ba97f23e6ff4400654afebfe42779fb /synapse/storage/databases/main/stream.py
parentAdd a config option for validating 'next_link' parameters against a domain wh... (diff)
downloadsynapse-63c0e9e1954fc7fc10a2575c54aecc8944de60f3.tar.xz
Add types to StreamToken and RoomStreamToken (#8279)
The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too.
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py21
1 files changed, 11 insertions, 10 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..08a13a8b47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
-    from_token: Optional[Tuple[int, int]],
-    to_token: Optional[Tuple[int, int]],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
     engine: BaseDatabaseEngine,
 ) -> str:
     """Creates an SQL expression to bound the columns by the pagination
@@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
+        parsed_end_token = RoomStreamToken.parse(end_token)
 
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
-            from_token=end_token,
+            from_token=parsed_end_token,
             limit=limit,
         )
 
@@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
+        parsed_from_key = RoomStreamToken.parse(from_key)
+        parsed_to_key = None
         if to_key:
-            to_key = RoomStreamToken.parse(to_key)
+            parsed_to_key = RoomStreamToken.parse(to_key)
 
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
-            from_key,
-            to_key,
+            parsed_from_key,
+            parsed_to_key,
             direction,
             limit,
             event_filter,