diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7ab6003f61..61373f0bfb 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -914,12 +914,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_last_event_in_room_before_stream_ordering_txn(
txn: LoggingTransaction,
) -> Optional[str]:
- # We need to handle the fact that the stream tokens can be vector
- # clocks. We do this by getting all rows between the minimum and
- # maximum stream ordering in the token, plus one row less than the
- # minimum stream ordering. We then filter the results against the
- # token and return the first row that matches.
-
+ # We're looking for the closest event at or before the token. We need to
+ # handle the fact that the stream token can be a vector clock (with an
+ # `instance_map`) and events can be persisted on different instances
+ # (sharded event persisters). The first subquery handles the events that
+ # would be within the vector clock and gets all rows between the minimum and
+ # maximum stream ordering in the token which need to be filtered against the
+ # `instance_map`. The second subquery handles the "before" case and finds
+ # the first row before the token. We then filter out any results past the
+ # token's vector clock and return the first row that matches.
+ min_stream = end_token.stream
+ max_stream = end_token.get_max_stream_pos()
+
+ # We use `union all` because we don't need any of the deduplication logic
+ # (`union` is really a union + distinct). `UNION ALL` does preserve the
+ # ordering of the operand queries but there is no actual gurantee that it
+ # has this behavior in all scenarios so we need the extra `ORDER BY` at the
+ # bottom.
sql = """
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
@@ -931,7 +942,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
AND rejections.event_id IS NULL
ORDER BY stream_ordering DESC
) AS a
- UNION
+ UNION ALL
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
FROM events
@@ -943,15 +954,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ORDER BY stream_ordering DESC
LIMIT 1
) AS b
+ ORDER BY stream_ordering DESC
"""
txn.execute(
sql,
(
room_id,
- end_token.stream,
- end_token.get_max_stream_pos(),
+ min_stream,
+ max_stream,
room_id,
- end_token.stream,
+ min_stream,
),
)
|