summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-11 12:22:55 +0100
committerGitHub <noreply@github.com>2020-09-11 12:22:55 +0100
commitfe8ed1b46f781faa45d1bba8f9308cf47c42010f (patch)
tree10d5b3cb181a70bd690a6e53461db5de394d9a4b /synapse/storage/databases/main
parentUse TLSv1.2 for fake servers in tests (#8208) (diff)
downloadsynapse-fe8ed1b46f781faa45d1bba8f9308cf47c42010f.tar.xz
Make `StreamToken.room_key` be a `RoomStreamToken` instance. (#8281)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/events.py21
-rw-r--r--synapse/storage/databases/main/stream.py75
2 files changed, 47 insertions, 49 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..9cd1403b38 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -213,7 +213,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []
+        results = []  # type: List[str]
 
         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
         )
 
     @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Ensure that we don't have the same event twice.
 
         Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
                 new_events_and_contexts[event.event_id] = (event, context)
         return list(new_events_and_contexts.values())
 
-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
         """Update min_depth for each room
 
         Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:
 
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 08a13a8b47..2e95518752 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Dict[str, Tuple[List[EventBase], str]]:
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
+        )
 
         if not room_ids:
             return {}
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return results
 
     def get_rooms_that_changed(
-        self, room_ids: Collection[str], from_key: str
+        self, room_ids: Collection[str], from_key: RoomStreamToken
     ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
-
-        Args:
-            room_ids
-            from_key: The room_key portion of a StreamToken
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
         return {
             room_id
             for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_room(
         self,
         room_id: str,
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if from_key == to_key:
             return [], from_key
 
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
 
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key
 
     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: str, to_key: str
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         if from_key == to_key:
             return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret
 
     async def get_recent_events_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[EventBase], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return (events, token)
 
     async def get_recent_event_ids_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[_EventDictReturn], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token
 
-        parsed_end_token = RoomStreamToken.parse(end_token)
-
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
-            from_token=parsed_end_token,
+            from_token=end_token,
             limit=limit,
         )
 
@@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             allow_none=allow_none,
         )
 
-    async def get_stream_token_for_event(self, event_id: str) -> str:
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A "s%d" stream token.
+            A stream token.
         """
         stream_id = await self.get_stream_id_for_event(event_id)
-        return "s%d" % (stream_id,)
+        return RoomStreamToken(None, stream_id)
 
     async def get_topological_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
@@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[_EventDictReturn], str]:
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token)
+        return rows, next_token
 
     async def paginate_room_events(
         self,
         room_id: str,
-        from_key: str,
-        to_key: Optional[str] = None,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
-        parsed_from_key = RoomStreamToken.parse(from_key)
-        parsed_to_key = None
-        if to_key:
-            parsed_to_key = RoomStreamToken.parse(to_key)
-
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
-            parsed_from_key,
-            parsed_to_key,
+            from_key,
+            to_key,
             direction,
             limit,
             event_filter,