Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--  synapse/storage/databases/main/stream.py  82
1 file changed, 35 insertions, 47 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index db20a3db30..92e96468b4 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
-    from_token: Optional[Tuple[int, int]],
-    to_token: Optional[Tuple[int, int]],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
     engine: BaseDatabaseEngine,
 ) -> str:
     """Creates an SQL expression to bound the columns by the pagination
@@ -259,16 +259,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     return " AND ".join(clauses), args
 
 
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
     which can be called in the initializer.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
         self._send_federation = hs.should_send_federation()
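
Replacing the `__metaclass__` attribute matters because Python 3 ignores that attribute entirely; only the `metaclass=` keyword in the class statement is honoured, so without this change the abstract methods were never actually enforced. The `super()` call is the equivalent zero-argument Python 3 form. A standalone illustration of the pattern (the names here are invented, not Synapse's):

    import abc

    class Worker(metaclass=abc.ABCMeta):
        # Python 3 only honours metaclass=...; a __metaclass__ attribute
        # would be silently ignored and the class would be instantiable.
        @abc.abstractmethod
        def get_max_stream_ordering(self) -> int:
            ...

    class ConcreteWorker(Worker):
        def get_max_stream_ordering(self) -> int:
            return 42

    ConcreteWorker()   # fine
    # Worker()         # TypeError: can't instantiate abstract class
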
@@ -310,11 +308,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Dict[str, Tuple[List[EventBase], str]]:
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -333,9 +331,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
+        )
 
         if not room_ids:
             return {}
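
With the signature change, callers hand in RoomStreamToken objects and the store reads the stream position straight from `.stream` instead of re-parsing a string on every call. A hedged sketch of the call-site shift; `store` is a placeholder for a StreamWorkerStore instance, the token strings are made up, and the import path assumes RoomStreamToken lives in synapse.types:

    from typing import Collection

    from synapse.types import RoomStreamToken

    async def fetch_new_events(store, room_ids: Collection[str]):
        # Previously the store accepted raw "s..." strings and called
        # RoomStreamToken.parse_stream_token() internally; now the parse
        # happens once, at the boundary.
        from_key = RoomStreamToken.parse_stream_token("s100")
        to_key = RoomStreamToken.parse_stream_token("s200")

        # from_key.stream == 100; the store uses it for the cache lookup.
        return await store.get_room_events_stream_for_rooms(
            room_ids, from_key=from_key, to_key=to_key, limit=10,
        )
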
@@ -364,16 +362,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return results
 
     def get_rooms_that_changed(
-        self, room_ids: Collection[str], from_key: str
+        self, room_ids: Collection[str], from_key: RoomStreamToken
     ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
-
-        Args:
-            room_ids
-            from_key: The room_key portion of a StreamToken
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
         return {
             room_id
             for room_id in room_ids
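
The Args block was dropped because the annotations now carry the same information; the body simply reads `.stream` directly. A small usage sketch, reusing the placeholders from the sketch above (synchronous call; the room IDs and position are invented):

    # Rooms that the in-memory stream-change cache knows have not changed
    # since the token are filtered out without a database query.
    changed = store.get_rooms_that_changed(
        {"!a:example.org", "!b:example.org"},
        from_key=RoomStreamToken(None, 100),   # stream-only token
    )
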
@@ -383,11 +377,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_room(
         self,
         room_id: str,
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
@@ -408,8 +402,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if from_key == to_key:
             return [], from_key
 
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
 
@@ -441,7 +435,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
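
Rather than formatting a fresh "s%d" string, the code now builds a stream-only RoomStreamToken. A sketch of the equivalence; the stream orderings are invented, and the final assertion assumes str() on a stream-only token still renders the "s..." form, as the `str(next_token)` call removed further down in this diff suggests:

    from synapse.types import RoomStreamToken

    stream_orderings = [105, 101, 103]   # hypothetical r.stream_ordering values

    old_key = "s%d" % min(stream_orderings)                  # "s101" (old style)
    new_key = RoomStreamToken(None, min(stream_orderings))   # structured token

    assert new_key.stream == 101
    assert str(new_key) == old_key       # serialise only at the edge
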
@@ -450,10 +444,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key
 
     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: str, to_key: str
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream
 
         if from_key == to_key:
             return []
@@ -491,8 +485,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret
 
     async def get_recent_events_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[EventBase], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -518,8 +512,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return (events, token)
 
     async def get_recent_event_ids_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[_EventDictReturn], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
@@ -535,8 +529,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
-
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
@@ -619,17 +611,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             allow_none=allow_none,
         )
 
-    async def get_stream_token_for_event(self, event_id: str) -> str:
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A "s%d" stream token.
+            A stream token.
         """
         stream_id = await self.get_stream_id_for_event(event_id)
-        return "s%d" % (stream_id,)
+        return RoomStreamToken(None, stream_id)
 
     async def get_topological_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
@@ -951,7 +943,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[_EventDictReturn], str]:
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -986,8 +978,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -1051,17 +1043,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
 
-        return rows, str(next_token)
+        return rows, next_token
 
     async def paginate_room_events(
         self,
         room_id: str,
-        from_key: str,
-        to_key: Optional[str] = None,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.
 
         Args:
@@ -1080,10 +1072,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
-        if to_key:
-            to_key = RoomStreamToken.parse(to_key)
-
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
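
With the parsing removed here as well, paginate_room_events expects its callers to have converted any string tokens up front, so parse/serialise happens once at the API boundary rather than inside every store method. A hedged before/after sketch of a caller (`store`, the room ID, and the token strings are placeholders):

    from synapse.types import RoomStreamToken

    async def page_backwards(store, room_id: str):
        # Old style: raw strings went in, a string came back.
        #   events, next_key = await store.paginate_room_events(
        #       room_id, from_key="t10-99", to_key="s200", limit=20)

        # New style: structured tokens in, a RoomStreamToken back out.
        events, next_token = await store.paginate_room_events(
            room_id,
            from_key=RoomStreamToken.parse("t10-99"),
            to_key=RoomStreamToken.parse("s200"),
            direction="b",
            limit=20,
        )
        return events, next_token
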