summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py235
1 files changed, 169 insertions, 66 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py

index 0a20f5db4c..bf10743574 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import ( Mapping, Optional, Sequence, + Set, Tuple, cast, ) +import attr from immutabledict import immutabledict from synapse.api.constants import EduTypes @@ -65,6 +67,57 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class ReceiptInRoom: + receipt_type: str + user_id: str + event_id: str + thread_id: Optional[str] + data: JsonMapping + + @staticmethod + def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping: + """Merge the given set of receipts (in a room) into the receipt + content format. + + Returns: + A mapping of the combined receipts: event ID -> receipt type -> user + ID -> receipt data. + """ + # MSC4102: always replace threaded receipts with unthreaded ones if + # there is a clash. This means we will drop some receipts, but MSC4102 + # is designed to drop semantically meaningless receipts, so this is + # okay. Previously, we would drop meaningful data! + # + # We do this by finding the unthreaded receipts, and then filtering out + # matching threaded receipts. + + # Set of (user_id, event_id) + unthreaded_receipts: Set[Tuple[str, str]] = { + (receipt.user_id, receipt.event_id) + for receipt in receipts + if receipt.thread_id is None + } + + # event_id -> receipt_type -> user_id -> receipt data + content: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {} + for receipt in receipts: + data = receipt.data + if receipt.thread_id is not None: + if (receipt.user_id, receipt.event_id) in unthreaded_receipts: + # Ignore threaded receipts if we have an unthreaded one. + continue + + data = dict(data) + data["thread_id"] = receipt.thread_id + + content.setdefault(receipt.event_id, {}).setdefault( + receipt.receipt_type, {} + )[receipt.user_id] = data + + return content + + class ReceiptsWorkerStore(SQLBaseStore): def __init__( self, @@ -401,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def f( txn: LoggingTransaction, - ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + ) -> Mapping[str, Sequence[ReceiptInRoom]]: if from_key: sql = """ SELECT stream_id, instance_name, room_id, receipt_type, @@ -431,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args)) - return [ - (room_id, receipt_type, user_id, event_id, thread_id, data) - for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn - if MultiWriterStreamToken.is_stream_position_in_range( + results: Dict[str, List[ReceiptInRoom]] = {} + for ( + stream_id, + instance_name, + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in txn: + if not MultiWriterStreamToken.is_stream_position_in_range( from_key, to_key, instance_name, stream_id + ): + continue + + results.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) ) - ] + + return results txn_results = await self.db_pool.runInteraction( "_get_linearized_receipts_for_rooms", f ) - results: JsonDict = {} - for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results: - # We want a single event per room, since we want to batch the - # receipts by room, event and type. - room_event = results.setdefault( - room_id, - {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}}, - ) - - # The content is of the form: - # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } - event_entry = room_event["content"].setdefault(event_id, {}) - receipt_type_dict = event_entry.setdefault(receipt_type, {}) - - # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash. - # Specifically: - # - if there is no existing receipt, great, set the data. - # - if there is an existing receipt, is it threaded (thread_id present)? - # YES: replace if this receipt has no thread id. NO: do not replace. - # This means we will drop some receipts, but MSC4102 is designed to drop semantically - # meaningless receipts, so this is okay. Previously, we would drop meaningful data! - receipt_data = db_to_json(data) - if user_id in receipt_type_dict: # existing receipt - # is the existing receipt threaded and we are currently processing an unthreaded one? - if "thread_id" in receipt_type_dict[user_id] and not thread_id: - receipt_type_dict[user_id] = ( - receipt_data # replace with unthreaded one - ) - else: # receipt does not exist, just set it - receipt_type_dict[user_id] = receipt_data - if thread_id: - receipt_type_dict[user_id]["thread_id"] = thread_id + results: JsonDict = { + room_id: { + "room_id": room_id, + "type": EduTypes.RECEIPT, + "content": ReceiptInRoom.merge_to_content(receipts), + } + for room_id, receipts in txn_results.items() + } results = { room_id: [results[room_id]] if room_id in results else [] @@ -485,7 +534,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_events( self, room_and_event_ids: Collection[Tuple[str, str]], - ) -> Sequence[JsonMapping]: + ) -> Mapping[str, Sequence[ReceiptInRoom]]: """Get all receipts for the given set of events. Arguments: @@ -495,6 +544,8 @@ class ReceiptsWorkerStore(SQLBaseStore): Returns: A list of receipts, one per room. """ + if not room_and_event_ids: + return {} def get_linearized_receipts_for_events_txn( txn: LoggingTransaction, @@ -514,8 +565,8 @@ class ReceiptsWorkerStore(SQLBaseStore): return txn.fetchall() - # room_id -> event_id -> receipt_type -> user_id -> receipt data - room_to_content: Dict[str, Dict[str, Dict[str, Dict[str, JsonMapping]]]] = {} + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} for batch in batch_iter(room_and_event_ids, 1000): batch_results = await self.db_pool.runInteraction( "get_linearized_receipts_for_events", @@ -531,33 +582,17 @@ class ReceiptsWorkerStore(SQLBaseStore): thread_id, data, ) in batch_results: - content = room_to_content.setdefault(room_id, {}) - user_receipts = content.setdefault(event_id, {}).setdefault( - receipt_type, {} + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) ) - receipt_data = db_to_json(data) - if thread_id is not None: - receipt_data["thread_id"] = thread_id - - # MSC4102: always replace threaded receipts with unthreaded ones - # if there is a clash. Specifically: - # - if there is no existing receipt, great, set the data. - # - if there is an existing receipt, is it threaded (thread_id - # present)? YES: replace if this receipt has no thread id. - # NO: do not replace. This means we will drop some receipts, but - # MSC4102 is designed to drop semantically meaningless receipts, - # so this is okay. Previously, we would drop meaningful data! - if user_id in user_receipts: - if "thread_id" in user_receipts[user_id] and not thread_id: - user_receipts[user_id] = receipt_data - else: - user_receipts[user_id] = receipt_data - - return [ - {"type": EduTypes.RECEIPT, "room_id": room_id, "content": content} - for room_id, content in room_to_content.items() - ] + return room_to_receipts @cached( num_args=2, @@ -630,6 +665,74 @@ class ReceiptsWorkerStore(SQLBaseStore): return results + async def get_linearized_receipts_for_user_in_rooms( + self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken + ) -> Mapping[str, Sequence[ReceiptInRoom]]: + """Fetch all receipts for the user in the given room. + + Returns: + A dict from room ID to receipts in the room. + """ + + def get_linearized_receipts_for_user_in_rooms_txn( + txn: LoggingTransaction, + batch_room_ids: StrCollection, + ) -> List[Tuple[str, str, str, str, Optional[str], str]]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + + sql = f""" + SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data + FROM receipts_linearized + WHERE {clause} AND user_id = ? AND stream_id <= ? + """ + + args.append(user_id) + args.append(to_key.get_max_stream_pos()) + + txn.execute(sql, args) + + return [ + (room_id, receipt_type, user_id, event_id, thread_id, data) + for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn + if MultiWriterStreamToken.is_stream_position_in_range( + low=None, + high=to_key, + instance_name=instance_name, + pos=stream_id, + ) + ] + + # room_id -> receipts + room_to_receipts: Dict[str, List[ReceiptInRoom]] = {} + for batch in batch_iter(room_ids, 1000): + batch_results = await self.db_pool.runInteraction( + "get_linearized_receipts_for_events", + get_linearized_receipts_for_user_in_rooms_txn, + batch, + ) + + for ( + room_id, + receipt_type, + user_id, + event_id, + thread_id, + data, + ) in batch_results: + room_to_receipts.setdefault(room_id, []).append( + ReceiptInRoom( + receipt_type=receipt_type, + user_id=user_id, + event_id=event_id, + thread_id=thread_id, + data=db_to_json(data), + ) + ) + + return room_to_receipts + async def get_rooms_with_receipts_between( self, room_ids: StrCollection,