diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0a20f5db4c..bf10743574 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
cast,
)
+import attr
from immutabledict import immutabledict
from synapse.api.constants import EduTypes
@@ -65,6 +67,57 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(auto_attribs=True, slots=True, frozen=True)
+class ReceiptInRoom:
+ receipt_type: str
+ user_id: str
+ event_id: str
+ thread_id: Optional[str]
+ data: JsonMapping
+
+ @staticmethod
+ def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping:
+ """Merge the given set of receipts (in a room) into the receipt
+ content format.
+
+ Returns:
+ A mapping of the combined receipts: event ID -> receipt type -> user
+ ID -> receipt data.
+ """
+ # MSC4102: always replace threaded receipts with unthreaded ones if
+ # there is a clash. This means we will drop some receipts, but MSC4102
+ # is designed to drop semantically meaningless receipts, so this is
+ # okay. Previously, we would drop meaningful data!
+ #
+ # We do this by finding the unthreaded receipts, and then filtering out
+ # matching threaded receipts.
+
+ # Set of (user_id, event_id)
+ unthreaded_receipts: Set[Tuple[str, str]] = {
+ (receipt.user_id, receipt.event_id)
+ for receipt in receipts
+ if receipt.thread_id is None
+ }
+
+ # event_id -> receipt_type -> user_id -> receipt data
+ content: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {}
+ for receipt in receipts:
+ data = receipt.data
+ if receipt.thread_id is not None:
+ if (receipt.user_id, receipt.event_id) in unthreaded_receipts:
+ # Ignore threaded receipts if we have an unthreaded one.
+ continue
+
+ data = dict(data)
+ data["thread_id"] = receipt.thread_id
+
+ content.setdefault(receipt.event_id, {}).setdefault(
+ receipt.receipt_type, {}
+ )[receipt.user_id] = data
+
+ return content
+
+
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -401,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
if from_key:
sql = """
SELECT stream_id, instance_name, room_id, receipt_type,
@@ -431,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return [
- (room_id, receipt_type, user_id, event_id, thread_id, data)
- for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
- if MultiWriterStreamToken.is_stream_position_in_range(
+ results: Dict[str, List[ReceiptInRoom]] = {}
+ for (
+ stream_id,
+ instance_name,
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in txn:
+ if not MultiWriterStreamToken.is_stream_position_in_range(
from_key, to_key, instance_name, stream_id
+ ):
+ continue
+
+ results.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
)
- ]
+
+ return results
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
- results: JsonDict = {}
- for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
- # We want a single event per room, since we want to batch the
- # receipts by room, event and type.
- room_event = results.setdefault(
- room_id,
- {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
- )
-
- # The content is of the form:
- # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(event_id, {})
- receipt_type_dict = event_entry.setdefault(receipt_type, {})
-
- # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash.
- # Specifically:
- # - if there is no existing receipt, great, set the data.
- # - if there is an existing receipt, is it threaded (thread_id present)?
- # YES: replace if this receipt has no thread id. NO: do not replace.
- # This means we will drop some receipts, but MSC4102 is designed to drop semantically
- # meaningless receipts, so this is okay. Previously, we would drop meaningful data!
- receipt_data = db_to_json(data)
- if user_id in receipt_type_dict: # existing receipt
- # is the existing receipt threaded and we are currently processing an unthreaded one?
- if "thread_id" in receipt_type_dict[user_id] and not thread_id:
- receipt_type_dict[user_id] = (
- receipt_data # replace with unthreaded one
- )
- else: # receipt does not exist, just set it
- receipt_type_dict[user_id] = receipt_data
- if thread_id:
- receipt_type_dict[user_id]["thread_id"] = thread_id
+ results: JsonDict = {
+ room_id: {
+ "room_id": room_id,
+ "type": EduTypes.RECEIPT,
+ "content": ReceiptInRoom.merge_to_content(receipts),
+ }
+ for room_id, receipts in txn_results.items()
+ }
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -485,7 +534,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_events(
self,
room_and_event_ids: Collection[Tuple[str, str]],
- ) -> Sequence[JsonMapping]:
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
"""Get all receipts for the given set of events.
Arguments:
@@ -495,6 +544,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
Returns:
-            A list of receipts, one per room.
+            A mapping of room ID to the receipts in that room.
"""
+ if not room_and_event_ids:
+ return {}
def get_linearized_receipts_for_events_txn(
txn: LoggingTransaction,
@@ -514,8 +565,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
return txn.fetchall()
- # room_id -> event_id -> receipt_type -> user_id -> receipt data
- room_to_content: Dict[str, Dict[str, Dict[str, Dict[str, JsonMapping]]]] = {}
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
for batch in batch_iter(room_and_event_ids, 1000):
batch_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_events",
@@ -531,33 +582,17 @@ class ReceiptsWorkerStore(SQLBaseStore):
thread_id,
data,
) in batch_results:
- content = room_to_content.setdefault(room_id, {})
- user_receipts = content.setdefault(event_id, {}).setdefault(
- receipt_type, {}
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
)
- receipt_data = db_to_json(data)
- if thread_id is not None:
- receipt_data["thread_id"] = thread_id
-
- # MSC4102: always replace threaded receipts with unthreaded ones
- # if there is a clash. Specifically:
- # - if there is no existing receipt, great, set the data.
- # - if there is an existing receipt, is it threaded (thread_id
- # present)? YES: replace if this receipt has no thread id.
- # NO: do not replace. This means we will drop some receipts, but
- # MSC4102 is designed to drop semantically meaningless receipts,
- # so this is okay. Previously, we would drop meaningful data!
- if user_id in user_receipts:
- if "thread_id" in user_receipts[user_id] and not thread_id:
- user_receipts[user_id] = receipt_data
- else:
- user_receipts[user_id] = receipt_data
-
- return [
- {"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}
- for room_id, content in room_to_content.items()
- ]
+ return room_to_receipts
@cached(
num_args=2,
@@ -630,6 +665,74 @@ class ReceiptsWorkerStore(SQLBaseStore):
return results
+ async def get_linearized_receipts_for_user_in_rooms(
+ self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
+ """Fetch all receipts for the user in the given room.
+
+ Returns:
+ A dict from room ID to receipts in the room.
+ """
+
+ def get_linearized_receipts_for_user_in_rooms_txn(
+ txn: LoggingTransaction,
+ batch_room_ids: StrCollection,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+
+ sql = f"""
+ SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized
+ WHERE {clause} AND user_id = ? AND stream_id <= ?
+ """
+
+ args.append(user_id)
+ args.append(to_key.get_max_stream_pos())
+
+ txn.execute(sql, args)
+
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ low=None,
+ high=to_key,
+ instance_name=instance_name,
+ pos=stream_id,
+ )
+ ]
+
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
+ for batch in batch_iter(room_ids, 1000):
+ batch_results = await self.db_pool.runInteraction(
+                "get_linearized_receipts_for_user_in_rooms",
+ get_linearized_receipts_for_user_in_rooms_txn,
+ batch,
+ )
+
+ for (
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in batch_results:
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
+ )
+
+ return room_to_receipts
+
async def get_rooms_with_receipts_between(
self,
room_ids: StrCollection,
|