summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16427.misc1
-rw-r--r--synapse/handlers/admin.py4
-rw-r--r--synapse/handlers/initial_sync.py3
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/storage/databases/main/stream.py22
-rw-r--r--synapse/types/__init__.py132
-rw-r--r--tests/handlers/test_appservice.py8
9 files changed, 115 insertions, 61 deletions
diff --git a/changelog.d/16427.misc b/changelog.d/16427.misc
new file mode 100644
index 0000000000..44f0e0595e
--- /dev/null
+++ b/changelog.d/16427.misc
@@ -0,0 +1 @@
+Factor out `MultiWriter` token from `RoomStreamToken`.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index ba9704a065..97fd1fd427 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -171,8 +171,8 @@ class AdminHandler:
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = RoomStreamToken(0, 0)
-            to_key = RoomStreamToken(None, stream_ordering)
+            from_key = RoomStreamToken(topological=0, stream=0)
+            to_key = RoomStreamToken(stream=stream_ordering)
 
             # Events that we've processed in this room
             written_events: Set[str] = set()
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5737f8014d..c34bd7db95 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -192,8 +192,7 @@ class InitialSyncHandler:
                     )
                 elif event.membership == Membership.LEAVE:
                     room_end_token = RoomStreamToken(
-                        None,
-                        event.stream_ordering,
+                        stream=event.stream_ordering,
                     )
                     deferred_room_state = run_in_background(
                         self._state_storage_controller.get_state_for_events,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a0c3b16819..4cdf0a8502 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1708,7 +1708,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
 
         if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = RoomStreamToken(None, from_key.stream)
+            from_key = RoomStreamToken(stream=from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7bd42f635f..744e080309 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -2333,7 +2333,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
+                    StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e42dade246..9bd0d764f8 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -146,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
             # RoomStreamToken expects [int] not Optional[int]
             assert event.internal_metadata.stream_ordering is not None
             room_token = RoomStreamToken(
-                event.depth, event.internal_metadata.stream_ordering
+                topological=event.depth, stream=event.internal_metadata.stream_ordering
             )
             token = await room_token.to_string(self.store)
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 5a3611c415..ea06e4eee0 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
         # when we are going backwards so we subtract one from the
         # stream part.
         last_stream_ordering -= 1
-    return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+    return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
 
 
 def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 if p > min_pos
             }
 
-        return RoomStreamToken(None, min_pos, immutabledict(positions))
+        return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+            key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         topo = await self.db_pool.runInteraction(
             "_get_max_topological_txn", self._get_max_topological_txn, room_id
         )
-        return RoomStreamToken(topo, stream_ordering)
+        return RoomStreamToken(topological=topo, stream=stream_ordering)
 
     @overload
     def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+        return RoomStreamToken(
+            topological=row["topological_ordering"], stream=row["stream_ordering"]
+        )
 
     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream
@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 topo = None
             internal = event.internal_metadata
-            internal.before = RoomStreamToken(topo, stream - 1)
-            internal.after = RoomStreamToken(topo, stream)
+            internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+            internal.after = RoomStreamToken(topological=topo, stream=stream)
             internal.order = (int(topo) if topo else 0, int(stream))
 
     async def get_events_around(
@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            results["topological_ordering"] - 1, results["stream_ordering"]
+            topological=results["topological_ordering"] - 1,
+            stream=results["stream_ordering"],
         )
 
         after_token = RoomStreamToken(
-            results["topological_ordering"], results["stream_ordering"]
+            topological=results["topological_ordering"],
+            stream=results["stream_ordering"],
         )
 
         rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 406d5b1611..09a88c86a7 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -61,6 +61,8 @@ from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
@@ -437,7 +439,78 @@ def map_username_to_mxid_localpart(
 
 
 @attr.s(frozen=True, slots=True, order=False)
-class RoomStreamToken:
+class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
+    """An abstract stream token class for streams that supports multiple
+    writers.
+
+    This works by keeping track of the stream position of each writer,
+    represented by a default `stream` attribute and a map of instance name to
+    stream position of any writers that are ahead of the default stream
+    position.
+    """
+
+    stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
+
+    instance_map: "immutabledict[str, int]" = attr.ib(
+        factory=immutabledict,
+        validator=attr.validators.deep_mapping(
+            key_validator=attr.validators.instance_of(str),
+            value_validator=attr.validators.instance_of(int),
+            mapping_validator=attr.validators.instance_of(immutabledict),
+        ),
+        kw_only=True,
+    )
+
+    @classmethod
+    @abc.abstractmethod
+    async def parse(cls, store: "DataStore", string: str) -> "Self":
+        """Parse the string representation of the token."""
+        ...
+
+    @abc.abstractmethod
+    async def to_string(self, store: "DataStore") -> str:
+        """Serialize the token into its string representation."""
+        ...
+
+    def copy_and_advance(self, other: "Self") -> "Self":
+        """Return a new token such that if an event is after both this token and
+        the other token, then its after the returned token too.
+        """
+
+        max_stream = max(self.stream, other.stream)
+
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return attr.evolve(
+            self, stream=max_stream, instance_map=immutabledict(instance_map)
+        )
+
+    def get_max_stream_pos(self) -> int:
+        """Get the maximum stream position referenced in this token.
+
+        The corresponding "min" position is, by definition just `self.stream`.
+
+        This is used to handle tokens that have non-empty `instance_map`, and so
+        reference stream positions after the `self.stream` position.
+        """
+        return max(self.instance_map.values(), default=self.stream)
+
+    def get_stream_pos_for_instance(self, instance_name: str) -> int:
+        """Get the stream position that the given writer was at at this token."""
+
+        # If we don't have an entry for the instance we can assume that it was
+        # at `self.stream`.
+        return self.instance_map.get(instance_name, self.stream)
+
+
+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken(AbstractMultiWriterStreamToken):
     """Tokens are positions between events. The token "s1" comes after event 1.
 
             s0    s1
@@ -514,16 +587,8 @@ class RoomStreamToken:
 
     topological: Optional[int] = attr.ib(
         validator=attr.validators.optional(attr.validators.instance_of(int)),
-    )
-    stream: int = attr.ib(validator=attr.validators.instance_of(int))
-
-    instance_map: "immutabledict[str, int]" = attr.ib(
-        factory=immutabledict,
-        validator=attr.validators.deep_mapping(
-            key_validator=attr.validators.instance_of(str),
-            value_validator=attr.validators.instance_of(int),
-            mapping_validator=attr.validators.instance_of(immutabledict),
-        ),
+        kw_only=True,
+        default=None,
     )
 
     def __attrs_post_init__(self) -> None:
@@ -583,17 +648,7 @@ class RoomStreamToken:
         if self.topological or other.topological:
             raise Exception("Can't advance topological tokens")
 
-        max_stream = max(self.stream, other.stream)
-
-        instance_map = {
-            instance: max(
-                self.instance_map.get(instance, self.stream),
-                other.instance_map.get(instance, other.stream),
-            )
-            for instance in set(self.instance_map).union(other.instance_map)
-        }
-
-        return RoomStreamToken(None, max_stream, immutabledict(instance_map))
+        return super().copy_and_advance(other)
 
     def as_historical_tuple(self) -> Tuple[int, int]:
         """Returns a tuple of `(topological, stream)` for historical tokens.
@@ -619,16 +674,6 @@ class RoomStreamToken:
         # at `self.stream`.
         return self.instance_map.get(instance_name, self.stream)
 
-    def get_max_stream_pos(self) -> int:
-        """Get the maximum stream position referenced in this token.
-
-        The corresponding "min" position is, by definition just `self.stream`.
-
-        This is used to handle tokens that have non-empty `instance_map`, and so
-        reference stream positions after the `self.stream` position.
-        """
-        return max(self.instance_map.values(), default=self.stream)
-
     async def to_string(self, store: "DataStore") -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
@@ -838,23 +883,28 @@ class StreamToken:
         return getattr(self, key.value)
 
 
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class PersistedEventPosition:
-    """Position of a newly persisted event with instance that persisted it.
-
-    This can be used to test whether the event is persisted before or after a
-    RoomStreamToken.
-    """
+class PersistedPosition:
+    """Position of a newly persisted row with instance that persisted it."""
 
     instance_name: str
     stream: int
 
-    def persisted_after(self, token: RoomStreamToken) -> bool:
+    def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
         return token.get_stream_pos_for_instance(self.instance_name) < self.stream
 
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition(PersistedPosition):
+    """Position of a newly persisted event with instance that persisted it.
+
+    This can be used to test whether the event is persisted before or after a
+    RoomStreamToken.
+    """
+
     def to_room_stream_token(self) -> RoomStreamToken:
         """Converts the position to a room stream token such that events
         persisted in the same room after this position will be after the
@@ -865,7 +915,7 @@ class PersistedEventPosition:
         """
         # Doing the naive thing satisfies the desired properties described in
         # the docstring.
-        return RoomStreamToken(None, self.stream)
+        return RoomStreamToken(stream=self.stream)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 8ce6ccf529..867dbd6001 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
                 [event],
             ]
         )
-        self.handler.notify_interested_services(RoomStreamToken(None, 1))
+        self.handler.notify_interested_services(RoomStreamToken(stream=1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, events=[event]
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             ]
         )
         self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
 
@@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             ]
         )
 
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.assertFalse(
             self.mock_as_api.query_user.called,
@@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.hs.get_application_service_handler()._notify_interested_services(
                 RoomStreamToken(
-                    None, self.hs.get_application_service_handler().current_max
+                    stream=self.hs.get_application_service_handler().current_max
                 )
             )
         )