diff --git a/changelog.d/16427.misc b/changelog.d/16427.misc
new file mode 100644
index 0000000000..44f0e0595e
--- /dev/null
+++ b/changelog.d/16427.misc
@@ -0,0 +1 @@
+Factor out `MultiWriter` token from `RoomStreamToken`.
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index ba9704a065..97fd1fd427 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -171,8 +171,8 @@ class AdminHandler:
else:
stream_ordering = room.stream_ordering
- from_key = RoomStreamToken(0, 0)
- to_key = RoomStreamToken(None, stream_ordering)
+ from_key = RoomStreamToken(topological=0, stream=0)
+ to_key = RoomStreamToken(stream=stream_ordering)
# Events that we've processed in this room
written_events: Set[str] = set()
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5737f8014d..c34bd7db95 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -192,8 +192,7 @@ class InitialSyncHandler:
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
- None,
- event.stream_ordering,
+ stream=event.stream_ordering,
)
deferred_room_state = run_in_background(
self._state_storage_controller.get_state_for_events,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a0c3b16819..4cdf0a8502 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1708,7 +1708,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
- from_key = RoomStreamToken(None, from_key.stream)
+ from_key = RoomStreamToken(stream=from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7bd42f635f..744e080309 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -2333,7 +2333,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
+ StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e42dade246..9bd0d764f8 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -146,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
room_token = RoomStreamToken(
- event.depth, event.internal_metadata.stream_ordering
+ topological=event.depth, stream=event.internal_metadata.stream_ordering
)
token = await room_token.to_string(self.store)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 5a3611c415..ea06e4eee0 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -266,7 +266,7 @@ def generate_next_token(
# when we are going backwards so we subtract one from the
# stream part.
last_stream_ordering -= 1
- return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+ return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
def _make_generic_sql_bound(
@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if p > min_pos
}
- return RoomStreamToken(None, min_pos, immutabledict(positions))
+ return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
async def get_room_events_stream_for_rooms(
self,
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+ key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- return RoomStreamToken(topo, stream_ordering)
+ return RoomStreamToken(topological=topo, stream=stream_ordering)
@overload
def get_stream_id_for_event_txn(
@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(
+ topological=row["topological_ordering"], stream=row["stream_ordering"]
+ )
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = RoomStreamToken(topo, stream - 1)
- internal.after = RoomStreamToken(topo, stream)
+ internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+ internal.after = RoomStreamToken(topological=topo, stream=stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
- results["topological_ordering"] - 1, results["stream_ordering"]
+ topological=results["topological_ordering"] - 1,
+ stream=results["stream_ordering"],
)
after_token = RoomStreamToken(
- results["topological_ordering"], results["stream_ordering"]
+ topological=results["topological_ordering"],
+ stream=results["stream_ordering"],
)
rows, start_token = self._paginate_room_events_txn(
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 406d5b1611..09a88c86a7 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -61,6 +61,8 @@ from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
+ from typing_extensions import Self
+
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
@@ -437,7 +439,78 @@ def map_username_to_mxid_localpart(
@attr.s(frozen=True, slots=True, order=False)
-class RoomStreamToken:
+class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
+ """An abstract stream token class for streams that supports multiple
+ writers.
+
+ This works by keeping track of the stream position of each writer,
+ represented by a default `stream` attribute and a map of instance name to
+ stream position of any writers that are ahead of the default stream
+ position.
+ """
+
+ stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
+
+ instance_map: "immutabledict[str, int]" = attr.ib(
+ factory=immutabledict,
+ validator=attr.validators.deep_mapping(
+ key_validator=attr.validators.instance_of(str),
+ value_validator=attr.validators.instance_of(int),
+ mapping_validator=attr.validators.instance_of(immutabledict),
+ ),
+ kw_only=True,
+ )
+
+ @classmethod
+ @abc.abstractmethod
+ async def parse(cls, store: "DataStore", string: str) -> "Self":
+ """Parse the string representation of the token."""
+ ...
+
+ @abc.abstractmethod
+ async def to_string(self, store: "DataStore") -> str:
+ """Serialize the token into its string representation."""
+ ...
+
+ def copy_and_advance(self, other: "Self") -> "Self":
+ """Return a new token such that if an event is after both this token and
+ the other token, then its after the returned token too.
+ """
+
+ max_stream = max(self.stream, other.stream)
+
+ instance_map = {
+ instance: max(
+ self.instance_map.get(instance, self.stream),
+ other.instance_map.get(instance, other.stream),
+ )
+ for instance in set(self.instance_map).union(other.instance_map)
+ }
+
+ return attr.evolve(
+ self, stream=max_stream, instance_map=immutabledict(instance_map)
+ )
+
+ def get_max_stream_pos(self) -> int:
+ """Get the maximum stream position referenced in this token.
+
+ The corresponding "min" position is, by definition just `self.stream`.
+
+ This is used to handle tokens that have non-empty `instance_map`, and so
+ reference stream positions after the `self.stream` position.
+ """
+ return max(self.instance_map.values(), default=self.stream)
+
+ def get_stream_pos_for_instance(self, instance_name: str) -> int:
+ """Get the stream position that the given writer was at at this token."""
+
+ # If we don't have an entry for the instance we can assume that it was
+ # at `self.stream`.
+ return self.instance_map.get(instance_name, self.stream)
+
+
+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken(AbstractMultiWriterStreamToken):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
@@ -514,16 +587,8 @@ class RoomStreamToken:
topological: Optional[int] = attr.ib(
validator=attr.validators.optional(attr.validators.instance_of(int)),
- )
- stream: int = attr.ib(validator=attr.validators.instance_of(int))
-
- instance_map: "immutabledict[str, int]" = attr.ib(
- factory=immutabledict,
- validator=attr.validators.deep_mapping(
- key_validator=attr.validators.instance_of(str),
- value_validator=attr.validators.instance_of(int),
- mapping_validator=attr.validators.instance_of(immutabledict),
- ),
+ kw_only=True,
+ default=None,
)
def __attrs_post_init__(self) -> None:
@@ -583,17 +648,7 @@ class RoomStreamToken:
if self.topological or other.topological:
raise Exception("Can't advance topological tokens")
- max_stream = max(self.stream, other.stream)
-
- instance_map = {
- instance: max(
- self.instance_map.get(instance, self.stream),
- other.instance_map.get(instance, other.stream),
- )
- for instance in set(self.instance_map).union(other.instance_map)
- }
-
- return RoomStreamToken(None, max_stream, immutabledict(instance_map))
+ return super().copy_and_advance(other)
def as_historical_tuple(self) -> Tuple[int, int]:
"""Returns a tuple of `(topological, stream)` for historical tokens.
@@ -619,16 +674,6 @@ class RoomStreamToken:
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)
- def get_max_stream_pos(self) -> int:
- """Get the maximum stream position referenced in this token.
-
- The corresponding "min" position is, by definition just `self.stream`.
-
- This is used to handle tokens that have non-empty `instance_map`, and so
- reference stream positions after the `self.stream` position.
- """
- return max(self.instance_map.values(), default=self.stream)
-
async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
@@ -838,23 +883,28 @@ class StreamToken:
return getattr(self, key.value)
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
@attr.s(slots=True, frozen=True, auto_attribs=True)
-class PersistedEventPosition:
- """Position of a newly persisted event with instance that persisted it.
-
- This can be used to test whether the event is persisted before or after a
- RoomStreamToken.
- """
+class PersistedPosition:
+ """Position of a newly persisted row with instance that persisted it."""
instance_name: str
stream: int
- def persisted_after(self, token: RoomStreamToken) -> bool:
+ def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
return token.get_stream_pos_for_instance(self.instance_name) < self.stream
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition(PersistedPosition):
+ """Position of a newly persisted event with instance that persisted it.
+
+ This can be used to test whether the event is persisted before or after a
+ RoomStreamToken.
+ """
+
def to_room_stream_token(self) -> RoomStreamToken:
"""Converts the position to a room stream token such that events
persisted in the same room after this position will be after the
@@ -865,7 +915,7 @@ class PersistedEventPosition:
"""
# Doing the naive thing satisfies the desired properties described in
# the docstring.
- return RoomStreamToken(None, self.stream)
+ return RoomStreamToken(stream=self.stream)
@attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 8ce6ccf529..867dbd6001 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
[event],
]
)
- self.handler.notify_interested_services(RoomStreamToken(None, 1))
+ self.handler.notify_interested_services(RoomStreamToken(stream=1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, events=[event]
@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
)
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ self.handler.notify_interested_services(RoomStreamToken(stream=0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
]
)
- self.handler.notify_interested_services(RoomStreamToken(None, 0))
+ self.handler.notify_interested_services(RoomStreamToken(stream=0))
self.assertFalse(
self.mock_as_api.query_user.called,
@@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services(
RoomStreamToken(
- None, self.hs.get_application_service_handler().current_max
+ stream=self.hs.get_application_service_handler().current_max
)
)
)
|