summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17395.feature1
-rw-r--r--synapse/api/constants.py4
-rw-r--r--synapse/handlers/sliding_sync.py75
-rw-r--r--synapse/rest/client/sync.py1
-rw-r--r--synapse/storage/databases/main/stream.py35
-rw-r--r--synapse/types/handlers/__init__.py8
-rw-r--r--tests/handlers/test_sliding_sync.py68
-rw-r--r--tests/rest/client/test_sync.py96
-rw-r--r--tests/storage/test_stream.py41
9 files changed, 295 insertions, 34 deletions
diff --git a/changelog.d/17395.feature b/changelog.d/17395.feature
new file mode 100644
index 0000000000..0c95b9f4a9
--- /dev/null
+++ b/changelog.d/17395.feature
@@ -0,0 +1 @@
+Add `rooms.bump_stamp` for easier client-side sorting in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 9265a271d2..12d18137e0 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -128,9 +128,13 @@ class EventTypes:
     SpaceParent: Final = "m.space.parent"
 
     Reaction: Final = "m.reaction"
+    Sticker: Final = "m.sticker"
+    LiveLocationShareStart: Final = "m.beacon_info"
 
     CallInvite: Final = "m.call.invite"
 
+    PollStart: Final = "m.poll.start"
+
 
 class ToDeviceEventTypes:
     RoomKeyRequest: Final = "m.room_key_request"
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index a1ddac903e..8e2f751c02 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -54,6 +54,17 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# The event types that clients should consider as new activity.
+DEFAULT_BUMP_EVENT_TYPES = {
+    EventTypes.Message,
+    EventTypes.Encrypted,
+    EventTypes.Sticker,
+    EventTypes.CallInvite,
+    EventTypes.PollStart,
+    EventTypes.LiveLocationShareStart,
+}
+
+
 def filter_membership_for_sync(
     *, membership: str, user_id: str, sender: Optional[str]
 ) -> bool:
@@ -285,6 +296,7 @@ class _RoomMembershipForUser:
             range
     """
 
+    room_id: str
     event_id: Optional[str]
     event_pos: PersistedEventPosition
     membership: str
@@ -469,7 +481,9 @@ class SlidingSyncHandler:
                         #
                         # Both sides of range are inclusive so we `+ 1`
                         max_num_rooms = range[1] - range[0] + 1
-                        for room_id, _ in sorted_room_info[range[0] :]:
+                        for room_membership in sorted_room_info[range[0] :]:
+                            room_id = room_membership.room_id
+
                             if len(room_ids_in_list) >= max_num_rooms:
                                 break
 
@@ -519,7 +533,7 @@ class SlidingSyncHandler:
                 user=sync_config.user,
                 room_id=room_id,
                 room_sync_config=room_sync_config,
-                rooms_membership_for_user_at_to_token=sync_room_map[room_id],
+                room_membership_for_user_at_to_token=sync_room_map[room_id],
                 from_token=from_token,
                 to_token=to_token,
             )
@@ -591,6 +605,7 @@ class SlidingSyncHandler:
             # (below) because they are potentially from the current snapshot time
             # instead from the time of the `to_token`.
             room_for_user.room_id: _RoomMembershipForUser(
+                room_id=room_for_user.room_id,
                 event_id=room_for_user.event_id,
                 event_pos=room_for_user.event_pos,
                 membership=room_for_user.membership,
@@ -691,6 +706,7 @@ class SlidingSyncHandler:
                     is not None
                 ):
                     sync_room_id_set[room_id] = _RoomMembershipForUser(
+                        room_id=room_id,
                         event_id=first_membership_change_after_to_token.prev_event_id,
                         event_pos=first_membership_change_after_to_token.prev_event_pos,
                         membership=first_membership_change_after_to_token.prev_membership,
@@ -785,6 +801,7 @@ class SlidingSyncHandler:
             # is their own leave event
             if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
                 filtered_sync_room_id_set[room_id] = _RoomMembershipForUser(
+                    room_id=room_id,
                     event_id=last_membership_change_in_from_to_range.event_id,
                     event_pos=last_membership_change_in_from_to_range.event_pos,
                     membership=last_membership_change_in_from_to_range.membership,
@@ -969,7 +986,7 @@ class SlidingSyncHandler:
         self,
         sync_room_map: Dict[str, _RoomMembershipForUser],
         to_token: StreamToken,
-    ) -> List[Tuple[str, _RoomMembershipForUser]]:
+    ) -> List[_RoomMembershipForUser]:
         """
         Sort by `stream_ordering` of the last event that the user should see in the
         room. `stream_ordering` is unique so we get a stable sort.
@@ -1007,12 +1024,17 @@ class SlidingSyncHandler:
             else:
                 # Otherwise, if the user has left/been invited/knocked/been banned from
                 # a room, they shouldn't see anything past that point.
+                #
+                # FIXME: It's possible that people should see beyond this point in
+                # invited/knocked cases if for example the room has
+                # `invite`/`world_readable` history visibility, see
+                # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
                 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
 
         return sorted(
-            sync_room_map.items(),
+            sync_room_map.values(),
             # Sort by the last activity (stream_ordering) in the room
-            key=lambda room_info: last_activity_in_room_map[room_info[0]],
+            key=lambda room_info: last_activity_in_room_map[room_info.room_id],
             # We want descending order
             reverse=True,
         )
@@ -1022,7 +1044,7 @@ class SlidingSyncHandler:
         user: UserID,
         room_id: str,
         room_sync_config: RoomSyncConfig,
-        rooms_membership_for_user_at_to_token: _RoomMembershipForUser,
+        room_membership_for_user_at_to_token: _RoomMembershipForUser,
         from_token: Optional[StreamToken],
         to_token: StreamToken,
     ) -> SlidingSyncResult.RoomResult:
@@ -1036,7 +1058,7 @@ class SlidingSyncHandler:
             room_id: The room ID to fetch data for
             room_sync_config: Config for what data we should fetch for a room in the
                 sync response.
-            rooms_membership_for_user_at_to_token: Membership information for the user
+            room_membership_for_user_at_to_token: Membership information for the user
                 in the room at the time of `to_token`.
             from_token: The point in the stream to sync from.
             to_token: The point in the stream to sync up to.
@@ -1056,7 +1078,7 @@ class SlidingSyncHandler:
         if (
             room_sync_config.timeline_limit > 0
             # No timeline for invite/knock rooms (just `stripped_state`)
-            and rooms_membership_for_user_at_to_token.membership
+            and room_membership_for_user_at_to_token.membership
             not in (Membership.INVITE, Membership.KNOCK)
         ):
             limited = False
@@ -1069,12 +1091,12 @@ class SlidingSyncHandler:
             # We're going to paginate backwards from the `to_token`
             from_bound = to_token.room_key
             # People shouldn't see past their leave/ban event
-            if rooms_membership_for_user_at_to_token.membership in (
+            if room_membership_for_user_at_to_token.membership in (
                 Membership.LEAVE,
                 Membership.BAN,
             ):
                 from_bound = (
-                    rooms_membership_for_user_at_to_token.event_pos.to_room_stream_token()
+                    room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
                 )
 
             # Determine whether we should limit the timeline to the token range.
@@ -1089,7 +1111,7 @@ class SlidingSyncHandler:
             to_bound = (
                 from_token.room_key
                 if from_token is not None
-                and not rooms_membership_for_user_at_to_token.newly_joined
+                and not room_membership_for_user_at_to_token.newly_joined
                 else None
             )
 
@@ -1126,7 +1148,7 @@ class SlidingSyncHandler:
                 self.storage_controllers,
                 user.to_string(),
                 timeline_events,
-                is_peeking=rooms_membership_for_user_at_to_token.membership
+                is_peeking=room_membership_for_user_at_to_token.membership
                 != Membership.JOIN,
                 filter_send_to_client=True,
             )
@@ -1181,16 +1203,16 @@ class SlidingSyncHandler:
         # Figure out any stripped state events for invite/knocks. This allows the
         # potential joiner to identify the room.
         stripped_state: List[JsonDict] = []
-        if rooms_membership_for_user_at_to_token.membership in (
+        if room_membership_for_user_at_to_token.membership in (
             Membership.INVITE,
             Membership.KNOCK,
         ):
             # This should never happen. If someone is invited/knocked on room, then
             # there should be an event for it.
-            assert rooms_membership_for_user_at_to_token.event_id is not None
+            assert room_membership_for_user_at_to_token.event_id is not None
 
             invite_or_knock_event = await self.store.get_event(
-                rooms_membership_for_user_at_to_token.event_id
+                room_membership_for_user_at_to_token.event_id
             )
 
             stripped_state = []
@@ -1206,7 +1228,7 @@ class SlidingSyncHandler:
             stripped_state.append(strip_event(invite_or_knock_event))
 
         # TODO: Handle state resets. For example, if we see
-        # `rooms_membership_for_user_at_to_token.membership = Membership.LEAVE` but
+        # `room_membership_for_user_at_to_token.membership = Membership.LEAVE` but
         # `required_state` doesn't include it, we should indicate to the client that a
         # state reset happened. Perhaps we should indicate this by setting `initial:
         # True` and empty `required_state`.
@@ -1226,7 +1248,7 @@ class SlidingSyncHandler:
         # `invite`/`knock` rooms only have `stripped_state`. See
         # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
         room_state: Optional[StateMap[EventBase]] = None
-        if rooms_membership_for_user_at_to_token.membership not in (
+        if room_membership_for_user_at_to_token.membership not in (
             Membership.INVITE,
             Membership.KNOCK,
         ):
@@ -1303,7 +1325,7 @@ class SlidingSyncHandler:
                 # initial sync
                 if initial:
                     # People shouldn't see past their leave/ban event
-                    if rooms_membership_for_user_at_to_token.membership in (
+                    if room_membership_for_user_at_to_token.membership in (
                         Membership.LEAVE,
                         Membership.BAN,
                     ):
@@ -1311,7 +1333,7 @@ class SlidingSyncHandler:
                             room_id,
                             stream_position=to_token.copy_and_replace(
                                 StreamKeyType.ROOM,
-                                rooms_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
+                                room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
                             ),
                             state_filter=state_filter,
                             # Partially-stated rooms should have all state events except for
@@ -1341,6 +1363,20 @@ class SlidingSyncHandler:
                     # we can return updates instead of the full required state.
                     raise NotImplementedError()
 
+        # Figure out the last bump event in the room
+        last_bump_event_result = (
+            await self.store.get_last_event_pos_in_room_before_stream_ordering(
+                room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES
+            )
+        )
+
+        # By default, just choose the membership event position
+        bump_stamp = room_membership_for_user_at_to_token.event_pos.stream
+        # But if we found a bump event, use that instead
+        if last_bump_event_result is not None:
+            _, bump_event_pos = last_bump_event_result
+            bump_stamp = bump_event_pos.stream
+
         return SlidingSyncResult.RoomResult(
             # TODO: Dummy value
             name=None,
@@ -1358,6 +1394,7 @@ class SlidingSyncHandler:
             prev_batch=prev_batch_token,
             limited=limited,
             num_live=num_live,
+            bump_stamp=bump_stamp,
             # TODO: Dummy values
             joined_count=0,
             invited_count=0,
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 2a22bc14ec..13aed1dc85 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -982,6 +982,7 @@ class SlidingSyncRestServlet(RestServlet):
         serialized_rooms: Dict[str, JsonDict] = {}
         for room_id, room_result in rooms.items():
             serialized_rooms[room_id] = {
+                "bump_stamp": room_result.bump_stamp,
                 "joined_count": room_result.joined_count,
                 "invited_count": room_result.invited_count,
                 "notification_count": room_result.notification_count,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index d34376b8df..be81025355 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1178,6 +1178,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         self,
         room_id: str,
         end_token: RoomStreamToken,
+        event_types: Optional[Collection[str]] = None,
     ) -> Optional[Tuple[str, PersistedEventPosition]]:
         """
         Returns the ID and event position of the last event in a room at or before a
@@ -1186,6 +1187,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Args:
             room_id
             end_token: The token used to stream from
+            event_types: Optional allowlist of event types to filter by
 
         Returns:
             The ID of the most recent event and it's position, or None if there are no
@@ -1207,9 +1209,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             min_stream = end_token.stream
             max_stream = end_token.get_max_stream_pos()
 
-            # We use `union all` because we don't need any of the deduplication logic
-            # (`union` is really a union + distinct). `UNION ALL` does preserve the
-            # ordering of the operand queries but there is no actual gurantee that it
+            event_type_clause = ""
+            event_type_args: List[str] = []
+            if event_types is not None and len(event_types) > 0:
+                event_type_clause, event_type_args = make_in_list_sql_clause(
+                    txn.database_engine, "type", event_types
+                )
+                event_type_clause = f"AND {event_type_clause}"
+
+            # We use `UNION ALL` because we don't need any of the deduplication logic
+            # (`UNION` is really a `UNION` + `DISTINCT`). `UNION ALL` does preserve the
+            # ordering of the operand queries but there is no actual guarantee that it
             # has this behavior in all scenarios so we need the extra `ORDER BY` at the
             # bottom.
             sql = """
@@ -1218,6 +1228,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     FROM events
                     LEFT JOIN rejections USING (event_id)
                     WHERE room_id = ?
+                        %s
                         AND ? < stream_ordering AND stream_ordering <= ?
                         AND NOT outlier
                         AND rejections.event_id IS NULL
@@ -1229,6 +1240,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     FROM events
                     LEFT JOIN rejections USING (event_id)
                     WHERE room_id = ?
+                        %s
                         AND stream_ordering <= ?
                         AND NOT outlier
                         AND rejections.event_id IS NULL
@@ -1236,16 +1248,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     LIMIT 1
                 ) AS b
                 ORDER BY stream_ordering DESC
-            """
+            """ % (
+                event_type_clause,
+                event_type_clause,
+            )
             txn.execute(
                 sql,
-                (
-                    room_id,
-                    min_stream,
-                    max_stream,
-                    room_id,
-                    min_stream,
-                ),
+                [room_id]
+                + event_type_args
+                + [min_stream, max_stream, room_id]
+                + event_type_args
+                + [min_stream],
             )
 
             for instance_name, stream_ordering, topological_ordering, event_id in txn:
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 3bd3268e59..43dcdf20dd 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -183,6 +183,13 @@ class SlidingSyncResult:
                 events because if a room not in the sliding window bumps into the window because
                 of an @mention it will have `initial: true` yet contain a single live event
                 (with potentially other old events in the timeline).
+            bump_stamp: The `stream_ordering` of the last event according to the
+                `bump_event_types`. This helps clients sort more readily without them
+                needing to pull in a bunch of the timeline to determine the last activity.
+                `bump_event_types` is a thing because for example, we don't want display
+                name changes to mark the room as unread and bump it to the top. For
+                encrypted rooms, we just have to consider any activity as a bump because we
+                can't see the content and the client has to figure it out for themselves.
             joined_count: The number of users with membership of join, including the client's
                 own user ID. (same as sync `v2 m.joined_member_count`)
             invited_count: The number of users with membership of invite. (same as sync v2
@@ -211,6 +218,7 @@ class SlidingSyncResult:
         limited: Optional[bool]
         # Only optional because it won't be included for invite/knock rooms with `stripped_state`
         num_live: Optional[int]
+        bump_stamp: int
         joined_count: int
         invited_count: int
         notification_count: int
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 5f83b637c5..9dd2363adc 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -2844,7 +2844,7 @@ class SortRoomsTestCase(HomeserverTestCase):
         )
 
         # Sort the rooms (what we're testing)
-        sorted_room_info = self.get_success(
+        sorted_sync_rooms = self.get_success(
             self.sliding_sync_handler.sort_rooms(
                 sync_room_map=sync_room_map,
                 to_token=after_rooms_token,
@@ -2852,7 +2852,7 @@ class SortRoomsTestCase(HomeserverTestCase):
         )
 
         self.assertEqual(
-            [room_id for room_id, _ in sorted_room_info],
+            [room_membership.room_id for room_membership in sorted_sync_rooms],
             [room_id2, room_id1],
         )
 
@@ -2927,7 +2927,7 @@ class SortRoomsTestCase(HomeserverTestCase):
         )
 
         # Sort the rooms (what we're testing)
-        sorted_room_info = self.get_success(
+        sorted_sync_rooms = self.get_success(
             self.sliding_sync_handler.sort_rooms(
                 sync_room_map=sync_room_map,
                 to_token=after_rooms_token,
@@ -2935,7 +2935,7 @@ class SortRoomsTestCase(HomeserverTestCase):
         )
 
         self.assertEqual(
-            [room_id for room_id, _ in sorted_room_info],
+            [room_membership.room_id for room_membership in sorted_sync_rooms],
             [room_id2, room_id1, room_id3],
             "Corresponding map to disambiguate the opaque room IDs: "
             + str(
@@ -2946,3 +2946,63 @@ class SortRoomsTestCase(HomeserverTestCase):
                 }
             ),
         )
+
+    def test_default_bump_event_types(self) -> None:
+        """
+        Test that we only consider the *latest* event in the room when sorting (not
+        `bump_event_types`).
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        room_id1 = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+        )
+        message_response = self.helper.send(room_id1, "message in room1", tok=user1_tok)
+        room_id2 = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+        )
+        self.helper.send(room_id2, "message in room2", tok=user1_tok)
+
+        # Send a reaction in room1 which isn't in `DEFAULT_BUMP_EVENT_TYPES` but we only
+        # care about sorting by the *latest* event in the room.
+        self.helper.send_event(
+            room_id1,
+            type=EventTypes.Reaction,
+            content={
+                "m.relates_to": {
+                    "event_id": message_response["event_id"],
+                    "key": "👍",
+                    "rel_type": "m.annotation",
+                }
+            },
+            tok=user1_tok,
+        )
+
+        after_rooms_token = self.event_sources.get_current_token()
+
+        # Get the rooms the user should be syncing with
+        sync_room_map = self.get_success(
+            self.sliding_sync_handler.get_sync_room_ids_for_user(
+                UserID.from_string(user1_id),
+                from_token=None,
+                to_token=after_rooms_token,
+            )
+        )
+
+        # Sort the rooms (what we're testing)
+        sorted_sync_rooms = self.get_success(
+            self.sliding_sync_handler.sort_rooms(
+                sync_room_map=sync_room_map,
+                to_token=after_rooms_token,
+            )
+        )
+
+        self.assertEqual(
+            [room_membership.room_id for room_membership in sorted_sync_rooms],
+            # room1 sorts before room2 because it has the latest event (the reaction).
+            # We only care about the *latest* event in the room.
+            [room_id1, room_id2],
+        )
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index cb2888409e..6ff1f03c9a 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -2029,6 +2029,102 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             channel.json_body["rooms"][room_id1],
         )
 
+    def test_rooms_bump_stamp(self) -> None:
+        """
+        Test that `bump_stamp` is present and pointing to relevant events.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        room_id1 = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+        )
+        event_response1 = message_response = self.helper.send(
+            room_id1, "message in room1", tok=user1_tok
+        )
+        event_pos1 = self.get_success(
+            self.store.get_position_for_event(event_response1["event_id"])
+        )
+        room_id2 = self.helper.create_room_as(
+            user1_id,
+            tok=user1_tok,
+        )
+        send_response2 = self.helper.send(room_id2, "message in room2", tok=user1_tok)
+        event_pos2 = self.get_success(
+            self.store.get_position_for_event(send_response2["event_id"])
+        )
+
+        # Send a reaction in room1 but it shouldn't affect the `bump_stamp`
+        # because reactions are not part of the `DEFAULT_BUMP_EVENT_TYPES`
+        self.helper.send_event(
+            room_id1,
+            type=EventTypes.Reaction,
+            content={
+                "m.relates_to": {
+                    "event_id": message_response["event_id"],
+                    "key": "👍",
+                    "rel_type": "m.annotation",
+                }
+            },
+            tok=user1_tok,
+        )
+
+        # Make the Sliding Sync request
+        timeline_limit = 100
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {
+                    "foo-list": {
+                        "ranges": [[0, 1]],
+                        "required_state": [],
+                        "timeline_limit": timeline_limit,
+                    }
+                }
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # Make sure it has the foo-list we requested
+        self.assertListEqual(
+            list(channel.json_body["lists"].keys()),
+            ["foo-list"],
+            channel.json_body["lists"].keys(),
+        )
+
+        # Make sure the list includes the rooms in the right order
+        self.assertListEqual(
+            list(channel.json_body["lists"]["foo-list"]["ops"]),
+            [
+                {
+                    "op": "SYNC",
+                    "range": [0, 1],
+                    # room1 sorts before room2 because it has the latest event (the
+                    # reaction)
+                    "room_ids": [room_id1, room_id2],
+                }
+            ],
+            channel.json_body["lists"]["foo-list"],
+        )
+
+        # The `bump_stamp` for room1 should point at the latest message (not the
+        # reaction since it's not one of the `DEFAULT_BUMP_EVENT_TYPES`)
+        self.assertEqual(
+            channel.json_body["rooms"][room_id1]["bump_stamp"],
+            event_pos1.stream,
+            channel.json_body["rooms"][room_id1],
+        )
+
+        # The `bump_stamp` for room2 should point at the latest message
+        self.assertEqual(
+            channel.json_body["rooms"][room_id2]["bump_stamp"],
+            event_pos2.stream,
+            channel.json_body["rooms"][room_id2],
+        )
+
     def test_rooms_newly_joined_incremental_sync(self) -> None:
         """
         Test that when we make an incremental sync with a `newly_joined` `rooms`, we are
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index aad46b1b44..9dea1af8ea 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -556,6 +556,47 @@ class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase):
             ),
         )
 
+    def test_restrict_event_types(self) -> None:
+        """
+        Test that we only consider given `event_types` when finding the last event
+        before a token.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+        event_response = self.helper.send_event(
+            room_id1,
+            type="org.matrix.special_message",
+            content={"body": "before1, target!"},
+            tok=user1_tok,
+        )
+        self.helper.send(room_id1, "before2", tok=user1_tok)
+
+        after_room_token = self.event_sources.get_current_token()
+
+        # Send some events after the token
+        self.helper.send_event(
+            room_id1,
+            type="org.matrix.special_message",
+            content={"body": "after1"},
+            tok=user1_tok,
+        )
+        self.helper.send(room_id1, "after2", tok=user1_tok)
+
+        last_event_result = self.get_success(
+            self.store.get_last_event_pos_in_room_before_stream_ordering(
+                room_id=room_id1,
+                end_token=after_room_token.room_key,
+                event_types=["org.matrix.special_message"],
+            )
+        )
+        assert last_event_result is not None
+        last_event_id, _ = last_event_result
+
+        # Make sure it's the last event before the token
+        self.assertEqual(last_event_id, event_response["event_id"])
+
 
 class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
     """