diff --git a/changelog.d/17295.bugfix b/changelog.d/17295.bugfix
new file mode 100644
index 0000000000..4484253bb8
--- /dev/null
+++ b/changelog.d/17295.bugfix
@@ -0,0 +1 @@
+Fix edge case in `/sync` returning the wrong state when using sharded event persisters.
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7ab6003f61..61373f0bfb 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -914,12 +914,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_last_event_in_room_before_stream_ordering_txn(
txn: LoggingTransaction,
) -> Optional[str]:
- # We need to handle the fact that the stream tokens can be vector
- # clocks. We do this by getting all rows between the minimum and
- # maximum stream ordering in the token, plus one row less than the
- # minimum stream ordering. We then filter the results against the
- # token and return the first row that matches.
-
+ # We're looking for the closest event at or before the token. We need to
+ # handle the fact that the stream token can be a vector clock (with an
+ # `instance_map`) and events can be persisted on different instances
+ # (sharded event persisters). The first subquery handles the events that
+ # would be within the vector clock and gets all rows between the minimum and
+ # maximum stream ordering in the token which need to be filtered against the
+ # `instance_map`. The second subquery handles the "before" case and finds
+ # the first row before the token. We then filter out any results past the
+ # token's vector clock and return the first row that matches.
+ min_stream = end_token.stream
+ max_stream = end_token.get_max_stream_pos()
+
+ # We use `union all` because we don't need any of the deduplication logic
+ # (`union` is really a union + distinct). `UNION ALL` does preserve the
+ ordering of the operand queries but there is no actual guarantee that it
+ # has this behavior in all scenarios so we need the extra `ORDER BY` at the
+ # bottom.
sql = """
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
@@ -931,7 +942,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
AND rejections.event_id IS NULL
ORDER BY stream_ordering DESC
) AS a
- UNION
+ UNION ALL
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
FROM events
@@ -943,15 +954,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ORDER BY stream_ordering DESC
LIMIT 1
) AS b
+ ORDER BY stream_ordering DESC
"""
txn.execute(
sql,
(
room_id,
- end_token.stream,
- end_token.get_max_stream_pos(),
+ min_stream,
+ max_stream,
room_id,
- end_token.stream,
+ min_stream,
),
)
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 2029cd9c68..ee34baf46f 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -19,7 +19,10 @@
#
#
-from typing import List
+import logging
+from typing import List, Tuple
+
+from immutabledict import immutabledict
from twisted.test.proto_helpers import MemoryReactor
@@ -28,11 +31,13 @@ from synapse.api.filtering import Filter
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, PersistedEventPosition, RoomStreamToken
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
+logger = logging.getLogger(__name__)
+
class PaginationTestCase(HomeserverTestCase):
"""
@@ -268,3 +273,263 @@ class PaginationTestCase(HomeserverTestCase):
}
chunk = self._filter_messages(filter)
self.assertEqual(chunk, [self.event_id_1, self.event_id_2, self.event_id_none])
+
+
+class GetLastEventInRoomBeforeStreamOrderingTestCase(HomeserverTestCase):
+ """
+ Test `get_last_event_in_room_before_stream_ordering(...)`
+ """
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+
+ def _update_persisted_instance_name_for_event(
+ self, event_id: str, instance_name: str
+ ) -> None:
+ """
+ Update the `instance_name` that persisted the event in the database.
+ """
+ return self.get_success(
+ self.store.db_pool.simple_update_one(
+ "events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"instance_name": instance_name},
+ )
+ )
+
+ def _send_event_on_instance(
+ self, instance_name: str, room_id: str, access_token: str
+ ) -> Tuple[JsonDict, PersistedEventPosition]:
+ """
+ Send an event in a room and mimic that it was persisted by a specific
+ instance/worker.
+ """
+ event_response = self.helper.send(
+ room_id, f"{instance_name} message", tok=access_token
+ )
+
+ self._update_persisted_instance_name_for_event(
+ event_response["event_id"], instance_name
+ )
+
+ event_pos = self.get_success(
+ self.store.get_position_for_event(event_response["event_id"])
+ )
+
+ return event_response, event_pos
+
+ def test_before_room_created(self) -> None:
+ """
+ Test that no event is returned if we are using a token before the room was even created
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ before_room_token = self.event_sources.get_current_token()
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id,
+ end_token=before_room_token.room_key,
+ )
+ )
+
+ self.assertIsNone(last_event)
+
+ def test_after_room_created(self) -> None:
+ """
+ Test that an event is returned if we are using a token after the room was created
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+
+ after_room_token = self.event_sources.get_current_token()
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id,
+ end_token=after_room_token.room_key,
+ )
+ )
+
+ self.assertIsNotNone(last_event)
+
+ def test_activity_in_other_rooms(self) -> None:
+ """
+ Test to make sure that the last event in the room is returned even if the
+ `stream_ordering` has advanced from activity in other rooms.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response = self.helper.send(room_id1, "target!", tok=user1_tok)
+ # Create another room to advance the stream_ordering
+ self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+
+ after_room_token = self.event_sources.get_current_token()
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id1,
+ end_token=after_room_token.room_key,
+ )
+ )
+
+ # Make sure it's the event we expect (which also means we know it's from the
+ # correct room)
+ self.assertEqual(last_event, event_response["event_id"])
+
+ def test_activity_after_token_has_no_effect(self) -> None:
+ """
+ Test to make sure we return the last event before the token even if there is
+ activity after it.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response = self.helper.send(room_id1, "target!", tok=user1_tok)
+
+ after_room_token = self.event_sources.get_current_token()
+
+ # Send some events after the token
+ self.helper.send(room_id1, "after1", tok=user1_tok)
+ self.helper.send(room_id1, "after2", tok=user1_tok)
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id1,
+ end_token=after_room_token.room_key,
+ )
+ )
+
+ # Make sure it's the last event before the token
+ self.assertEqual(last_event, event_response["event_id"])
+
+ def test_last_event_within_sharded_token(self) -> None:
+ """
+ Test to make sure we can find the last event that is *within* the sharded
+ token (a token that has an `instance_map` and looks like
+ `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`). We are specifically testing
+ that we can find an event within the tokens minimum and instance
+ `stream_ordering`.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response1, event_pos1 = self._send_event_on_instance(
+ "worker1", room_id1, user1_tok
+ )
+ event_response2, event_pos2 = self._send_event_on_instance(
+ "worker1", room_id1, user1_tok
+ )
+ event_response3, event_pos3 = self._send_event_on_instance(
+ "worker1", room_id1, user1_tok
+ )
+
+ # Create another room to advance the `stream_ordering` on the same worker
+ # so we can sandwich event3 in the middle of the token
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response4, event_pos4 = self._send_event_on_instance(
+ "worker1", room_id2, user1_tok
+ )
+
+ # Assemble a token that encompasses event1 -> event4 on worker1
+ end_token = RoomStreamToken(
+ stream=event_pos2.stream,
+ instance_map=immutabledict({"worker1": event_pos4.stream}),
+ )
+
+ # Send some events after the token
+ self.helper.send(room_id1, "after1", tok=user1_tok)
+ self.helper.send(room_id1, "after2", tok=user1_tok)
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id1,
+ end_token=end_token,
+ )
+ )
+
+ # Should find closest event at/before the token in room1
+ self.assertEqual(
+ last_event,
+ event_response3["event_id"],
+ f"We expected {event_response3['event_id']} but saw {last_event} which corresponds to "
+ + str(
+ {
+ "event1": event_response1["event_id"],
+ "event2": event_response2["event_id"],
+ "event3": event_response3["event_id"],
+ }
+ ),
+ )
+
+ def test_last_event_before_sharded_token(self) -> None:
+ """
+ Test to make sure we can find the last event that is *before* the sharded token
+ (a token that has an `instance_map` and looks like
+ `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response1, event_pos1 = self._send_event_on_instance(
+ "worker1", room_id1, user1_tok
+ )
+ event_response2, event_pos2 = self._send_event_on_instance(
+ "worker1", room_id1, user1_tok
+ )
+
+ # Create another room to advance the `stream_ordering` on the same worker
+ room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ event_response3, event_pos3 = self._send_event_on_instance(
+ "worker1", room_id2, user1_tok
+ )
+ event_response4, event_pos4 = self._send_event_on_instance(
+ "worker1", room_id2, user1_tok
+ )
+
+ # Assemble a token that encompasses event3 -> event4 on worker1
+ end_token = RoomStreamToken(
+ stream=event_pos3.stream,
+ instance_map=immutabledict({"worker1": event_pos4.stream}),
+ )
+
+ # Send some events after the token
+ self.helper.send(room_id1, "after1", tok=user1_tok)
+ self.helper.send(room_id1, "after2", tok=user1_tok)
+
+ last_event = self.get_success(
+ self.store.get_last_event_in_room_before_stream_ordering(
+ room_id=room_id1,
+ end_token=end_token,
+ )
+ )
+
+ # Should find closest event at/before the token in room1
+ self.assertEqual(
+ last_event,
+ event_response2["event_id"],
+ f"We expected {event_response2['event_id']} but saw {last_event} which corresponds to "
+ + str(
+ {
+ "event1": event_response1["event_id"],
+ "event2": event_response2["event_id"],
+ }
+ ),
+ )
|