diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 3bde0ae0d4..9964331510 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -30,10 +30,12 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
cast,
)
+import attr
from immutabledict import immutabledict
from synapse.api.constants import EduTypes
@@ -43,6 +45,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_tuple_in_list_sql_clause,
)
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -51,10 +54,12 @@ from synapse.types import (
JsonMapping,
MultiWriterStreamToken,
PersistedPosition,
+ StrCollection,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -62,6 +67,57 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(auto_attribs=True, slots=True, frozen=True)
class ReceiptInRoom:
    """A single receipt within a room.

    Attributes mirror the columns of `receipts_linearized`:
    receipt type, the user who sent it, the event it points at, the
    (optional) thread it belongs to, and the raw receipt data.
    """

    receipt_type: str
    user_id: str
    event_id: str
    thread_id: Optional[str]
    data: JsonMapping

    @staticmethod
    def merge_to_content(receipts: Collection["ReceiptInRoom"]) -> JsonMapping:
        """Merge the given set of receipts (in a room) into the receipt
        content format.

        Returns:
            A mapping of the combined receipts: event ID -> receipt type -> user
            ID -> receipt data.
        """
        # MSC4102: always replace threaded receipts with unthreaded ones if
        # there is a clash. This means we will drop some receipts, but MSC4102
        # is designed to drop semantically meaningless receipts, so this is
        # okay. Previously, we would drop meaningful data!
        #
        # First record which (user, event) pairs have an unthreaded receipt,
        # then skip any threaded receipt that collides with one of them.

        # Set of (user_id, event_id) pairs covered by an unthreaded receipt.
        unthreaded: Set[Tuple[str, str]] = set()
        for receipt in receipts:
            if receipt.thread_id is None:
                unthreaded.add((receipt.user_id, receipt.event_id))

        # event_id -> receipt_type -> user_id -> receipt data
        merged: Dict[str, Dict[str, Dict[str, JsonMapping]]] = {}
        for receipt in receipts:
            if receipt.thread_id is None:
                receipt_data: JsonMapping = receipt.data
            else:
                if (receipt.user_id, receipt.event_id) in unthreaded:
                    # An unthreaded receipt takes precedence; drop this one.
                    continue

                # Annotate a copy of the data with the thread ID.
                annotated = dict(receipt.data)
                annotated["thread_id"] = receipt.thread_id
                receipt_data = annotated

            by_type = merged.setdefault(receipt.event_id, {})
            by_user = by_type.setdefault(receipt.receipt_type, {})
            by_user[receipt.user_id] = receipt_data

        return merged
+
+
class ReceiptsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -398,7 +454,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
if from_key:
sql = """
SELECT stream_id, instance_name, room_id, receipt_type,
@@ -428,50 +484,46 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))
- return [
- (room_id, receipt_type, user_id, event_id, thread_id, data)
- for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
- if MultiWriterStreamToken.is_stream_position_in_range(
+ results: Dict[str, List[ReceiptInRoom]] = {}
+ for (
+ stream_id,
+ instance_name,
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in txn:
+ if not MultiWriterStreamToken.is_stream_position_in_range(
from_key, to_key, instance_name, stream_id
+ ):
+ continue
+
+ results.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
)
- ]
+
+ return results
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
- results: JsonDict = {}
- for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
- # We want a single event per room, since we want to batch the
- # receipts by room, event and type.
- room_event = results.setdefault(
- room_id,
- {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
- )
-
- # The content is of the form:
- # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
- event_entry = room_event["content"].setdefault(event_id, {})
- receipt_type_dict = event_entry.setdefault(receipt_type, {})
-
- # MSC4102: always replace threaded receipts with unthreaded ones if there is a clash.
- # Specifically:
- # - if there is no existing receipt, great, set the data.
- # - if there is an existing receipt, is it threaded (thread_id present)?
- # YES: replace if this receipt has no thread id. NO: do not replace.
- # This means we will drop some receipts, but MSC4102 is designed to drop semantically
- # meaningless receipts, so this is okay. Previously, we would drop meaningful data!
- receipt_data = db_to_json(data)
- if user_id in receipt_type_dict: # existing receipt
- # is the existing receipt threaded and we are currently processing an unthreaded one?
- if "thread_id" in receipt_type_dict[user_id] and not thread_id:
- receipt_type_dict[user_id] = (
- receipt_data # replace with unthreaded one
- )
- else: # receipt does not exist, just set it
- receipt_type_dict[user_id] = receipt_data
- if thread_id:
- receipt_type_dict[user_id]["thread_id"] = thread_id
+ results: JsonDict = {
+ room_id: {
+ "room_id": room_id,
+ "type": EduTypes.RECEIPT,
+ "content": ReceiptInRoom.merge_to_content(receipts),
+ }
+ for room_id, receipts in txn_results.items()
+ }
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -479,6 +531,69 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
+ async def get_linearized_receipts_for_events(
+ self,
+ room_and_event_ids: Collection[Tuple[str, str]],
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
+ """Get all receipts for the given set of events.
+
+ Arguments:
+ room_and_event_ids: A collection of 2-tuples of room ID and
+ event IDs to fetch receipts for
+
+ Returns:
+ A list of receipts, one per room.
+ """
+ if not room_and_event_ids:
+ return {}
+
+ def get_linearized_receipts_for_events_txn(
+ txn: LoggingTransaction,
+ room_id_event_id_tuples: Collection[Tuple[str, str]],
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ clause, args = make_tuple_in_list_sql_clause(
+ self.database_engine, ("room_id", "event_id"), room_id_event_id_tuples
+ )
+
+ sql = f"""
+ SELECT room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized
+ WHERE {clause}
+ """
+
+ txn.execute(sql, args)
+
+ return txn.fetchall()
+
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
+ for batch in batch_iter(room_and_event_ids, 1000):
+ batch_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_events",
+ get_linearized_receipts_for_events_txn,
+ batch,
+ )
+
+ for (
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in batch_results:
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
+ )
+
+ return room_to_receipts
+
@cached(
num_args=2,
)
@@ -550,6 +665,114 @@ class ReceiptsWorkerStore(SQLBaseStore):
return results
+ async def get_linearized_receipts_for_user_in_rooms(
+ self, user_id: str, room_ids: StrCollection, to_key: MultiWriterStreamToken
+ ) -> Mapping[str, Sequence[ReceiptInRoom]]:
+ """Fetch all receipts for the user in the given room.
+
+ Returns:
+ A dict from room ID to receipts in the room.
+ """
+
+ def get_linearized_receipts_for_user_in_rooms_txn(
+ txn: LoggingTransaction,
+ batch_room_ids: StrCollection,
+ ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+
+ sql = f"""
+ SELECT instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
+ FROM receipts_linearized
+ WHERE {clause} AND user_id = ? AND stream_id <= ?
+ """
+
+ args.append(user_id)
+ args.append(to_key.get_max_stream_pos())
+
+ txn.execute(sql, args)
+
+ return [
+ (room_id, receipt_type, user_id, event_id, thread_id, data)
+ for instance_name, stream_id, room_id, receipt_type, user_id, event_id, thread_id, data in txn
+ if MultiWriterStreamToken.is_stream_position_in_range(
+ low=None,
+ high=to_key,
+ instance_name=instance_name,
+ pos=stream_id,
+ )
+ ]
+
+ # room_id -> receipts
+ room_to_receipts: Dict[str, List[ReceiptInRoom]] = {}
+ for batch in batch_iter(room_ids, 1000):
+ batch_results = await self.db_pool.runInteraction(
+ "get_linearized_receipts_for_events",
+ get_linearized_receipts_for_user_in_rooms_txn,
+ batch,
+ )
+
+ for (
+ room_id,
+ receipt_type,
+ user_id,
+ event_id,
+ thread_id,
+ data,
+ ) in batch_results:
+ room_to_receipts.setdefault(room_id, []).append(
+ ReceiptInRoom(
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_id=event_id,
+ thread_id=thread_id,
+ data=db_to_json(data),
+ )
+ )
+
+ return room_to_receipts
+
+ async def get_rooms_with_receipts_between(
+ self,
+ room_ids: StrCollection,
+ from_key: MultiWriterStreamToken,
+ to_key: MultiWriterStreamToken,
+ ) -> StrCollection:
+ """Given a set of room_ids, find out which ones (may) have receipts
+ between the two tokens (> `from_token` and <= `to_token`)."""
+
+ room_ids = self._receipts_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
+ if not room_ids:
+ return []
+
+ def f(txn: LoggingTransaction, room_ids: StrCollection) -> StrCollection:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+
+ sql = f"""
+ SELECT DISTINCT room_id FROM receipts_linearized
+ WHERE {clause} AND ? < stream_id AND stream_id <= ?
+ """
+ args.append(from_key.stream)
+ args.append(to_key.get_max_stream_pos())
+
+ txn.execute(sql, args)
+
+ return [room_id for (room_id,) in txn]
+
+ results: List[str] = []
+ for batch in batch_iter(room_ids, 1000):
+ batch_result = await self.db_pool.runInteraction(
+ "get_rooms_with_receipts_between", f, batch
+ )
+ results.extend(batch_result)
+
+ return results
+
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:
@@ -807,9 +1030,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
- """ % (
- clause,
- )
+ """ % (clause,)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
@@ -954,6 +1175,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME,
self._background_receipts_graph_unique_index,
)
+ self.db_pool.updates.register_background_index_update(
+ update_name="receipts_room_id_event_id_index",
+ index_name="receipts_linearized_event_id",
+ table="receipts_linearized",
+ columns=("room_id", "event_id"),
+ )
async def _populate_receipt_event_stream_ordering(
self, progress: JsonDict, batch_size: int
|