summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8423.misc1
-rw-r--r--synapse/events/__init__.py6
-rw-r--r--synapse/handlers/admin.py2
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/initial_sync.py3
-rw-r--r--synapse/handlers/pagination.py5
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/handlers/sync.py20
-rw-r--r--synapse/notifier.py4
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/rest/admin/__init__.py3
-rw-r--r--synapse/storage/databases/main/stream.py38
-rw-r--r--synapse/storage/persist_events.py5
-rw-r--r--synapse/types.py53
-rw-r--r--tests/rest/client/v1/test_rooms.py8
-rw-r--r--tests/storage/test_purge.py10
16 files changed, 96 insertions, 76 deletions
diff --git a/changelog.d/8423.misc b/changelog.d/8423.misc
new file mode 100644
index 0000000000..7260e3fa41
--- /dev/null
+++ b/changelog.d/8423.misc
@@ -0,0 +1 @@
+Various refactors to simplify stream token handling.
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index bf800a3852..dc49df0812 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -23,7 +23,7 @@ from typing import Dict, Optional, Tuple, Type
 from unpaddedbase64 import encode_base64
 
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 
@@ -118,8 +118,8 @@ class _EventInternalMetadata:
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
     # I'm pretty sure that these are never persisted to the database, so shouldn't
     # be here
-    before = DictProperty("before")  # type: str
-    after = DictProperty("after")  # type: str
+    before = DictProperty("before")  # type: RoomStreamToken
+    after = DictProperty("after")  # type: RoomStreamToken
     order = DictProperty("order")  # type: Tuple[int, int]
 
     def get_dict(self) -> JsonDict:
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index dd981c597e..1ce2091b46 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                 if not events:
                     break
 
-                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
+                from_key = events[-1].internal_metadata.after
 
                 events = await filter_events_for_client(self.storage, user_id, events)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 4149520d6c..b9d9098104 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -29,7 +29,6 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
-    RoomStreamToken,
     StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler):
 
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_id = self.store.get_room_max_stream_ordering()
-        now_room_key = RoomStreamToken(None, now_room_id)
+        now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 8cd7eb22a3..43f15435de 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -325,7 +325,8 @@ class InitialSyncHandler(BaseHandler):
         if limit is None:
             limit = 10
 
-        stream_token = await self.store.get_stream_token_for_event(member_event_id)
+        leave_position = await self.store.get_position_for_event(member_event_id)
+        stream_token = leave_position.to_room_stream_token()
 
         messages, token = await self.store.get_recent_events_for_room(
             room_id, limit=limit, end_token=stream_token
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index a0b3bdb5e0..d6779a4b44 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -25,7 +25,7 @@ from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
-from synapse.types import Requester, RoomStreamToken
+from synapse.types import Requester
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -373,10 +373,9 @@ class PaginationHandler:
                     # case "JOIN" would have been returned.
                     assert member_event_id
 
-                    leave_token_str = await self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
-                    leave_token = RoomStreamToken.parse(leave_token_str)
                     assert leave_token.topological is not None
 
                     if leave_token.topological < curr_topo:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 11bf146bed..836b3f381a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1134,14 +1134,14 @@ class RoomEventSource:
                 events[:] = events[:limit]
 
             if events:
-                end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
+                end_key = events[-1].internal_metadata.after
             else:
                 end_key = to_key
 
         return (events, end_key)
 
     def get_current_key(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
+        return self.store.get_room_max_token()
 
     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e948efef2e..bfe2583002 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -519,7 +519,7 @@ class SyncHandler:
             if len(recents) > timeline_limit:
                 limited = True
                 recents = recents[-timeline_limit:]
-                room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
+                room_key = recents[0].internal_metadata.before
 
             prev_batch_token = now_token.copy_and_replace("room_key", room_key)
 
@@ -1595,16 +1595,24 @@ class SyncHandler:
 
             if leave_events:
                 leave_event = leave_events[-1]
-                leave_stream_token = await self.store.get_stream_token_for_event(
+                leave_position = await self.store.get_position_for_event(
                     leave_event.event_id
                 )
-                leave_token = since_token.copy_and_replace(
-                    "room_key", leave_stream_token
-                )
 
-                if since_token and since_token.is_after(leave_token):
+                # If the leave event happened before the since token then we
+                # bail.
+                if since_token and not leave_position.persisted_after(
+                    since_token.room_key
+                ):
                     continue
 
+                # We can safely convert the position of the leave event into a
+                # stream token as it'll only be used in the context of this
+                # room. (c.f. the docstring of `to_room_stream_token`).
+                leave_token = since_token.copy_and_replace(
+                    "room_key", leave_position.to_room_stream_token()
+                )
+
                 # If this is an out of band message, like a remote invite
                 # rejection, we include it in the recents batch. Otherwise, we
                 # let _load_filtered_recents handle fetching the correct
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 441b3d15e2..59415f6f88 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -163,7 +163,7 @@ class _NotifierUserStream:
         """
         # Immediately wake up stream if something has already since happened
         # since their last token.
-        if self.last_notified_token.is_after(token):
+        if self.last_notified_token != token:
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
@@ -470,7 +470,7 @@ class Notifier:
         async def check_for_updates(
             before_token: StreamToken, after_token: StreamToken
         ) -> EventStreamResult:
-            if not after_token.is_after(before_token):
+            if after_token == before_token:
                 return EventStreamResult([], (from_token, from_token))
 
             events = []  # type: List[EventBase]
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 55af3d41ea..e165429cad 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
+from synapse.types import PersistedEventPosition, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -152,9 +152,7 @@ class ReplicationDataHandler:
                 if event.type == EventTypes.Member:
                     extra_users = (UserID.from_string(event.state_key),)
 
-                max_token = RoomStreamToken(
-                    None, self.store.get_room_max_stream_ordering()
-                )
+                max_token = self.store.get_room_max_token()
                 event_pos = PersistedEventPosition(instance_name, token)
                 self.notifier.on_new_room_event(
                     event, event_pos, max_token, extra_users
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 5c5f00b213..ba53f66f02 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -109,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet):
             if event.room_id != room_id:
                 raise SynapseError(400, "Event is for wrong room.")
 
-            token = await self.store.get_topological_token_for_event(event_id)
+            room_token = await self.store.get_topological_token_for_event(event_id)
+            token = str(room_token)
 
             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
         elif "purge_up_to_ts" in body:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 92e96468b4..37249f1e3f 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,7 +35,6 @@ what sort order was used:
     - topological tokems: "t%d-%d", where the integers map to the topological
       and stream ordering columns respectively.
 """
-
 import abc
 import logging
 from collections import namedtuple
@@ -54,7 +53,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import Collection, RoomStreamToken
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -305,6 +304,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
+    def get_room_max_token(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
@@ -611,26 +613,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             allow_none=allow_none,
         )
 
-    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
-        """The stream token for an event
-        Args:
-            event_id: The id of the event to look up a stream token for.
-        Raises:
-            StoreError if the event wasn't in the database.
-        Returns:
-            A stream token.
+    async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
+        """Get the persisted position for an event
         """
-        stream_id = await self.get_stream_id_for_event(event_id)
-        return RoomStreamToken(None, stream_id)
+        row = await self.db_pool.simple_select_one(
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcols=("stream_ordering", "instance_name"),
+            desc="get_position_for_event",
+        )
+
+        return PersistedEventPosition(
+            row["instance_name"] or "master", row["stream_ordering"]
+        )
 
-    async def get_topological_token_for_event(self, event_id: str) -> str:
+    async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A "t%d-%d" topological token.
+            A `RoomStreamToken` topological token.
         """
         row = await self.db_pool.simple_select_one(
             table="events",
@@ -638,7 +642,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
+        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
 
     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream
@@ -687,8 +691,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             else:
                 topo = None
             internal = event.internal_metadata
-            internal.before = str(RoomStreamToken(topo, stream - 1))
-            internal.after = str(RoomStreamToken(topo, stream))
+            internal.before = RoomStreamToken(topo, stream - 1)
+            internal.after = RoomStreamToken(topo, stream)
             internal.order = (int(topo) if topo else 0, int(stream))
 
     async def get_events_around(
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index ded6cf9655..72939f3984 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -229,7 +229,7 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        return RoomStreamToken(None, self.main_store.get_current_events_token())
+        return self.main_store.get_room_max_token()
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
@@ -247,11 +247,10 @@ class EventsPersistenceStorage:
 
         await make_deferred_yieldable(deferred)
 
-        max_persisted_id = self.main_store.get_current_events_token()
         event_stream_id = event.internal_metadata.stream_ordering
 
         pos = PersistedEventPosition(self._instance_name, event_stream_id)
-        return pos, RoomStreamToken(None, max_persisted_id)
+        return pos, self.main_store.get_room_max_token()
 
     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
diff --git a/synapse/types.py b/synapse/types.py
index ec39f9e1e8..02bcc197ec 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -413,6 +413,18 @@ class RoomStreamToken:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
 
+    def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+        """Return a new token such that if an event is after both this token and
+        the other token, then its after the returned token too.
+        """
+
+        if self.topological or other.topological:
+            raise Exception("Can't advance topological tokens")
+
+        max_stream = max(self.stream, other.stream)
+
+        return RoomStreamToken(None, max_stream)
+
     def as_tuple(self) -> Tuple[Optional[int], int]:
         return (self.topological, self.stream)
 
@@ -458,31 +470,20 @@ class StreamToken:
     def room_stream_id(self):
         return self.room_key.stream
 
-    def is_after(self, other):
-        """Does this token contain events that the other doesn't?"""
-        return (
-            (other.room_stream_id < self.room_stream_id)
-            or (int(other.presence_key) < int(self.presence_key))
-            or (int(other.typing_key) < int(self.typing_key))
-            or (int(other.receipt_key) < int(self.receipt_key))
-            or (int(other.account_data_key) < int(self.account_data_key))
-            or (int(other.push_rules_key) < int(self.push_rules_key))
-            or (int(other.to_device_key) < int(self.to_device_key))
-            or (int(other.device_list_key) < int(self.device_list_key))
-            or (int(other.groups_key) < int(self.groups_key))
-        )
-
     def copy_and_advance(self, key, new_value) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
-        new_token = self.copy_and_replace(key, new_value)
         if key == "room_key":
-            new_id = new_token.room_stream_id
-            old_id = self.room_stream_id
-        else:
-            new_id = int(getattr(new_token, key))
-            old_id = int(getattr(self, key))
+            new_token = self.copy_and_replace(
+                "room_key", self.room_key.copy_and_advance(new_value)
+            )
+            return new_token
+
+        new_token = self.copy_and_replace(key, new_value)
+        new_id = int(getattr(new_token, key))
+        old_id = int(getattr(self, key))
+
         if old_id < new_id:
             return new_token
         else:
@@ -509,6 +510,18 @@ class PersistedEventPosition:
     def persisted_after(self, token: RoomStreamToken) -> bool:
         return token.stream < self.stream
 
+    def to_room_stream_token(self) -> RoomStreamToken:
+        """Converts the position to a room stream token such that events
+        persisted in the same room after this position will be after the
+        returned `RoomStreamToken`.
+
+        Note: no guarentees are made about ordering w.r.t. events in other
+        rooms.
+        """
+        # Doing the naive thing satisfies the desired properties described in
+        # the docstring.
+        return RoomStreamToken(None, self.stream)
+
 
 class ThirdPartyInstanceID(
     namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0a567b032f..a3287011e9 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -902,15 +902,15 @@ class RoomMessageListTestCase(RoomBase):
 
         # Send a first message in the room, which will be removed by the purge.
         first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
-        first_token = self.get_success(
-            store.get_topological_token_for_event(first_event_id)
+        first_token = str(
+            self.get_success(store.get_topological_token_for_event(first_event_id))
         )
 
         # Send a second message in the room, which won't be removed, and which we'll
         # use as the marker to purge events before.
         second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
-        second_token = self.get_success(
-            store.get_topological_token_for_event(second_event_id)
+        second_token = str(
+            self.get_success(store.get_topological_token_for_event(second_event_id))
         )
 
         # Send a third event in the room to ensure we don't fall under any edge case
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 918387733b..723cd28933 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -47,8 +47,8 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_storage()
 
         # Get the topological token
-        event = self.get_success(
-            store.get_topological_token_for_event(last["event_id"])
+        event = str(
+            self.get_success(store.get_topological_token_for_event(last["event_id"]))
         )
 
         # Purge everything before this topological token
@@ -74,12 +74,10 @@ class PurgeTests(HomeserverTestCase):
         storage = self.hs.get_datastore()
 
         # Set the topological token higher than it should be
-        event = self.get_success(
+        token = self.get_success(
             storage.get_topological_token_for_event(last["event_id"])
         )
-        event = "t{}-{}".format(
-            *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
-        )
+        event = "t{}-{}".format(token.topological + 1, token.stream + 1)
 
         # Purge everything before this topological token
         purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))