summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-08 16:48:15 +0100
committerGitHub <noreply@github.com>2020-09-08 16:48:15 +0100
commit63c0e9e1954fc7fc10a2575c54aecc8944de60f3 (patch)
tree29b8c5045ba97f23e6ff4400654afebfe42779fb /synapse/storage/databases/main
parentAdd a config option for validating 'next_link' parameters against a domain wh... (diff)
downloadsynapse-63c0e9e1954fc7fc10a2575c54aecc8944de60f3.tar.xz
Add types to StreamToken and RoomStreamToken (#8279)
The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too.
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/devices.py7
-rw-r--r--synapse/storage/databases/main/stream.py21
2 files changed, 14 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..306fc6947c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
         }
 
     async def get_users_whose_devices_changed(
-        self, from_key: str, user_ids: Iterable[str]
+        self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             The set of user_ids whose devices have changed since `from_key`
         """
-        from_key = int(from_key)
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     async def get_users_whose_signatures_changed(
-        self, user_id: str, from_key: str
+        self, user_id: str, from_key: int
     ) -> Set[str]:
         """Get the users who have new cross-signing signatures made by `user_id` since
         `from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             A set of user IDs with updated signatures.
         """
-        from_key = int(from_key)
+
         if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
             sql = """
                 SELECT DISTINCT user_ids FROM user_signature_stream
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..08a13a8b47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
-    from_token: Optional[Tuple[int, int]],
-    to_token: Optional[Tuple[int, int]],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
     engine: BaseDatabaseEngine,
 ) -> str:
     """Creates an SQL expression to bound the columns by the pagination
@@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
+        parsed_end_token = RoomStreamToken.parse(end_token)
 
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
-            from_token=end_token,
+            from_token=parsed_end_token,
             limit=limit,
         )
 
@@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
+        parsed_from_key = RoomStreamToken.parse(from_key)
+        parsed_to_key = None
         if to_key:
-            to_key = RoomStreamToken.parse(to_key)
+            parsed_to_key = RoomStreamToken.parse(to_key)
 
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
-            from_key,
-            to_key,
+            parsed_from_key,
+            parsed_to_key,
             direction,
             limit,
             event_filter,