diff --git a/changelog.d/17617.misc b/changelog.d/17617.misc
new file mode 100644
index 0000000000..ba05648965
--- /dev/null
+++ b/changelog.d/17617.misc
@@ -0,0 +1 @@
+Always return the user's own read receipts in sliding sync.
diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py
index f05f45f72c..a2d4f24f9c 100644
--- a/synapse/handlers/sliding_sync/extensions.py
+++ b/synapse/handlers/sliding_sync/extensions.py
@@ -12,12 +12,13 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
+import itertools
import logging
from typing import TYPE_CHECKING, AbstractSet, Dict, Mapping, Optional, Sequence, Set
from typing_extensions import assert_never
-from synapse.api.constants import AccountDataTypes
+from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.sliding_sync.types import (
HaveSentRoomFlag,
@@ -25,6 +26,7 @@ from synapse.handlers.sliding_sync.types import (
PerConnectionState,
)
from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.types import (
DeviceListUpdates,
JsonMapping,
@@ -485,15 +487,21 @@ class SlidingSyncExtensionHandler:
initial_rooms.add(room_id)
continue
- # If we're sending down the room from scratch again for some reason, we
- # should always resend the receipts as well (regardless of if
- # we've sent them down before). This is to mimic the behaviour
- # of what happens on initial sync, where you get a chunk of
- # timeline with all of the corresponding receipts for the events in the timeline.
+ # If we're sending down the room from scratch again for some
+ # reason, we should always resend the receipts as well
+ # (regardless of if we've sent them down before). This is to
+ # mimic the behaviour of what happens on initial sync, where you
+ # get a chunk of timeline with all of the corresponding receipts
+ # for the events in the timeline.
+ #
+ # We also resend down receipts when we "expand" the timeline,
+ # (see the "XXX: Odd behavior" in
+ # `synapse.handlers.sliding_sync`).
room_result = actual_room_response_map.get(room_id)
- if room_result is not None and room_result.initial:
- initial_rooms.add(room_id)
- continue
+ if room_result is not None:
+ if room_result.initial or room_result.unstable_expanded_timeline:
+ initial_rooms.add(room_id)
+ continue
room_status = previous_connection_state.receipts.have_sent_room(room_id)
if room_status.status == HaveSentRoomFlag.LIVE:
@@ -536,21 +544,49 @@ class SlidingSyncExtensionHandler:
)
fetched_receipts.extend(previously_receipts)
- # For rooms we haven't previously sent down, we could send all receipts
- # from that room but we only want to include receipts for events
- # in the timeline to avoid bloating and blowing up the sync response
- # as the number of users in the room increases. (this behavior is part of the spec)
- initial_rooms_and_event_ids = [
- (room_id, event.event_id)
- for room_id in initial_rooms
- if room_id in actual_room_response_map
- for event in actual_room_response_map[room_id].timeline_events
- ]
- if initial_rooms_and_event_ids:
+ if initial_rooms:
+ # We also always send down receipts for the current user.
+ user_receipts = (
+ await self.store.get_linearized_receipts_for_user_in_rooms(
+ user_id=sync_config.user.to_string(),
+ room_ids=initial_rooms,
+ to_key=to_token.receipt_key,
+ )
+ )
+
+ # For rooms we haven't previously sent down, we could send all receipts
+ # from that room but we only want to include receipts for events
+ # in the timeline to avoid bloating and blowing up the sync response
+ # as the number of users in the room increases. (this behavior is part of the spec)
+ initial_rooms_and_event_ids = [
+ (room_id, event.event_id)
+ for room_id in initial_rooms
+ if room_id in actual_room_response_map
+ for event in actual_room_response_map[room_id].timeline_events
+ ]
initial_receipts = await self.store.get_linearized_receipts_for_events(
room_and_event_ids=initial_rooms_and_event_ids,
)
- fetched_receipts.extend(initial_receipts)
+
+ # Combine the receipts for a room and add them to
+ # `fetched_receipts`
+ for room_id in initial_receipts.keys() | user_receipts.keys():
+ receipt_content = ReceiptInRoom.merge_to_content(
+ list(
+ itertools.chain(
+ initial_receipts.get(room_id, []),
+ user_receipts.get(room_id, []),
+ )
+ )
+ )
+
+ fetched_receipts.append(
+ {
+ "room_id": room_id,
+ "type": EduTypes.RECEIPT,
+ "content": receipt_content,
+ }
+ )
fetched_receipts = ReceiptEventSource.filter_out_private_receipts(
fetched_receipts, sync_config.user.to_string()
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0a20f5db4c..bf10743574 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
cast,
)
+import attr
from immutabledict import immutabledict
from synapse.api.constants import EduTypes
@@ -65,6 +67,57 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class ReceiptInRoom:
+ receipt_type: str
+ user_id: str
+ event_id: str
+ thread_id: Optional[str]
+ data: JsonMapping
+
+ @staticmethod
+ def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping:
+ """Merge the given set of receipts (in a room) into the receipt
+ content format.
+
+ Returns:
+ A mapping of the combined receipts: event ID -> receipt type -> user
+ ID -> receipt data.
+ """
+ # MSC4102: always replace threaded receipts with unthreaded ones if
+ # there is a clash. This means we will drop some receipts, but MSC4102
+ # is designed to drop semantically meaningless receipts, so this is
+ # okay. Previously, we would drop meaningful data!
+ #
+ # We do this by finding the unthreaded receipts, and then filtering out
+ # matching threaded receipts.
+
+ # Set of (user_id, event_id)
+ unthreaded_receipts: Set[Tuple[str, str]] = {
+ (receipt.user_id, receipt.event_id)
+ for receipt in receipts
+ if receipt.thread_id is None
+ }
+
+ # event_id -> receipt_type -> user_id -> receipt data
+ content: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {}
+ for receipt in receipts:
+ data = receipt.data
+ if receipt.thread_id is not None:
+ if (receipt.user_id, receipt.event_id) in unthreaded_receipts:
+ # Ignore threaded receipts if we have an unthreaded one.
+ continue
+
+ data = dict(data)
+ data["thread_id"] = receipt.thread_id
+
+ content.setdefault(receipt.event_id, {}).setdefault(
+ receipt.receipt_type, {}
+ )[receipt.user_id] = data
+
+ return content
+
+
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -401,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
if from_key:
sql = """
SELECT stream_id, instance_name, room_id, receipt_type,
@@ -431,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return [
- (room_id, receipt_type, user_id, event_id, thread_id, data)
- for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
- if MultiWriterStreamToken.is_stream_position_in_range(
+ results: Dict[str, List[ReceiptInRoom]] = {}
+ for (
+ stream_id,
+ instance_name,
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in txn:
+ if not MultiWriterStreamToken.is_stream_position_in_range(
from_key, to_key, instance_name, stream_id
+ ):
+ continue
+
+ results.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
)
- ]
+
+ return results
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
- results: JsonDict = {}
- for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
- # We want a single event per room, since we want to batch the
- # receipts by room, event and type.
- room_event = results.setdefault(
- room_id,
- {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
- )
-
- # The content is of the form:
- # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(event_id, {})
- receipt_type_dict = event_entry.setdefault(receipt_type, {})
-
- # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash.
- # Specifically:
- # - if there is no existing receipt, great, set the data.
- # - if there is an existing receipt, is it threaded (thread_id present)?
- # YES: replace if this receipt has no thread id. NO: do not replace.
- # This means we will drop some receipts, but MSC4102 is designed to drop semantically
- # meaningless receipts, so this is okay. Previously, we would drop meaningful data!
- receipt_data = db_to_json(data)
- if user_id in receipt_type_dict: # existing receipt
- # is the existing receipt threaded and we are currently processing an unthreaded one?
- if "thread_id" in receipt_type_dict[user_id] and not thread_id:
- receipt_type_dict[user_id] = (
- receipt_data # replace with unthreaded one
- )
- else: # receipt does not exist, just set it
- receipt_type_dict[user_id] = receipt_data
- if thread_id:
- receipt_type_dict[user_id]["thread_id"] = thread_id
+ results: JsonDict = {
+ room_id: {
+ "room_id": room_id,
+ "type": EduTypes.RECEIPT,
+ "content": ReceiptInRoom.merge_to_content(receipts),
+ }
+ for room_id, receipts in txn_results.items()
+ }
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -485,7 +534,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_events(
self,
room_and_event_ids: Collection[Tuple[str, str]],
- ) -> Sequence[JsonMapping]:
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
"""Get all receipts for the given set of events.
Arguments:
@@ -495,6 +544,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
Returns:
A list of receipts, one per room.
"""
+ if not room_and_event_ids:
+ return {}
def get_linearized_receipts_for_events_txn(
txn: LoggingTransaction,
@@ -514,8 +565,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
return txn.fetchall()
- # room_id -> event_id -> receipt_type -> user_id -> receipt data
- room_to_content: Dict[str, Dict[str, Dict[str, Dict[str, JsonMapping]]]] = {}
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
for batch in batch_iter(room_and_event_ids, 1000):
batch_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_events",
@@ -531,33 +582,17 @@ class ReceiptsWorkerStore(SQLBaseStore):
thread_id,
data,
) in batch_results:
- content = room_to_content.setdefault(room_id, {})
- user_receipts = content.setdefault(event_id, {}).setdefault(
- receipt_type, {}
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
)
- receipt_data = db_to_json(data)
- if thread_id is not None:
- receipt_data["thread_id"] = thread_id
-
- # MSC4102: always replace threaded receipts with unthreaded ones
- # if there is a clash. Specifically:
- # - if there is no existing receipt, great, set the data.
- # - if there is an existing receipt, is it threaded (thread_id
- # present)? YES: replace if this receipt has no thread id.
- # NO: do not replace. This means we will drop some receipts, but
- # MSC4102 is designed to drop semantically meaningless receipts,
- # so this is okay. Previously, we would drop meaningful data!
- if user_id in user_receipts:
- if "thread_id" in user_receipts[user_id] and not thread_id:
- user_receipts[user_id] = receipt_data
- else:
- user_receipts[user_id] = receipt_data
-
- return [
- {"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}
- for room_id, content in room_to_content.items()
- ]
+ return room_to_receipts
@cached(
num_args=2,
@@ -630,6 +665,74 @@ class ReceiptsWorkerStore(SQLBaseStore):
return results
+ async def get_linearized_receipts_for_user_in_rooms(
+ self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
+ """Fetch all receipts for the user in the given rooms.
+
+ Returns:
+ A dict from room ID to receipts in the room.
+ """
+
+ def get_linearized_receipts_for_user_in_rooms_txn(
+ txn: LoggingTransaction,
+ batch_room_ids: StrCollection,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+
+ sql = f"""
+ SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized
+ WHERE {clause} AND user_id = ? AND stream_id <= ?
+ """
+
+ args.append(user_id)
+ args.append(to_key.get_max_stream_pos())
+
+ txn.execute(sql, args)
+
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ low=None,
+ high=to_key,
+ instance_name=instance_name,
+ pos=stream_id,
+ )
+ ]
+
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
+ for batch in batch_iter(room_ids, 1000):
+ batch_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_user_in_rooms",
+ get_linearized_receipts_for_user_in_rooms_txn,
+ batch,
+ )
+
+ for (
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in batch_results:
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
+ )
+
+ return room_to_receipts
+
async def get_rooms_with_receipts_between(
self,
room_ids: StrCollection,
diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py
index 39c51b367c..e842349ed2 100644
--- a/tests/rest/client/sliding_sync/test_extension_receipts.py
+++ b/tests/rest/client/sliding_sync/test_extension_receipts.py
@@ -782,3 +782,135 @@ class SlidingSyncReceiptsExtensionTestCase(SlidingSyncBase):
{user2_id},
exact=True,
)
+
+ def test_return_own_read_receipts(self) -> None:
+ """Test that we always send the user's own read receipts in initial
+ rooms, even if the receipts don't match events in the timeline.
+ """
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipts into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user1_tok)
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, _ = self.do_sync(sync_body, tok=user1_tok)
+
+ # We should get our own receipt in room1, even though it's not in the
+ # timeline limit.
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should only see our read receipt, not the other user's.
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user1_id},
+ exact=True,
+ )
+
+ def test_read_receipts_expanded_timeline(self) -> None:
+ """Test that we get read receipts when we expand the timeline limit (`unstable_expanded_timeline`)."""
+
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+ # Send a message and read receipt into room1
+ event_response = self.helper.send(room_id1, body="new event", tok=user2_tok)
+ room1_event_id = event_response["event_id"]
+
+ self.helper.send_read_receipt(room_id1, room1_event_id, tok=user2_tok)
+
+ # Now send a message so the above message is not in the timeline.
+ self.helper.send(room_id1, body="new event", tok=user2_tok)
+
+ # Make a SS request for only the latest message.
+ sync_body = {
+ "lists": {
+ "main": {
+ "ranges": [[0, 0]],
+ "required_state": [],
+ "timeline_limit": 1,
+ }
+ },
+ "extensions": {
+ "receipts": {
+ "enabled": True,
+ }
+ },
+ }
+ response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+ # We shouldn't see user2's read receipt, as it's not in the timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ set(),
+ exact=True,
+ )
+
+ # Now do another request with a room subscription with an increased timeline limit
+ sync_body["room_subscriptions"] = {
+ room_id1: {
+ "required_state": [],
+ "timeline_limit": 2,
+ }
+ }
+
+ response_body, from_token = self.do_sync(
+ sync_body, since=from_token, tok=user1_tok
+ )
+
+ # Assert that we did actually get an expanded timeline
+ room_response = response_body["rooms"][room_id1]
+ self.assertNotIn("initial", room_response)
+ self.assertEqual(room_response["unstable_expanded_timeline"], True)
+
+ # We should now see user2's read receipt, as it's in the expanded timeline
+ self.assertIncludes(
+ response_body["extensions"]["receipts"].get("rooms").keys(),
+ {room_id1},
+ exact=True,
+ )
+
+ # We should only see user2's read receipt; user1 never sent one.
+ receipt = response_body["extensions"]["receipts"]["rooms"][room_id1]
+ self.assertIncludes(
+ receipt["content"][room1_event_id][ReceiptTypes.READ].keys(),
+ {user2_id},
+ exact=True,
+ )
|